1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Timeout related to client reconnection.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the grace period a disconnected client has to rejoin before its state is
/// released — confirm against its users.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up rooms and channel-buffer
/// collaborators still attributed to stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a bit
/// longer than that means the old pods are gone before we clean up after them.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in the visible part
// of this file; the descriptions are inferred from their names — confirm.
/// Page size presumably used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum length of a channel message (units — bytes or
/// chars — not visible here; TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size presumably used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the incoming envelope, which the closure built by
/// `Server::add_handler` downcasts back to its concrete type, plus the sender's
/// [`Session`]. It returns a boxed `'static` future that does the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The RPC server: owns the peer, the shared pool of live connections, and the
/// table of registered message handlers.
pub struct Server {
    /// Current server id; mutable so that the test-only `reset` can change it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; set to `true` by `teardown`. Each connection task
    /// subscribes to it.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. Holding it across
/// an `.await` would therefore make the enclosing future `!Send`, which the
/// `Send` boxed futures used by handlers reject at compile time.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and connection pool.
///
/// The pool is serialized through its lock guard (via `serialize_deref`), so
/// the pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject
    /// non-user principals. Handlers wrapped in `dev_server_handler` reject
    /// non-dev-server principals. Unwrapped handlers receive the raw [`Session`].
    ///
    /// # Panics
    ///
    /// Panics (in `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates silently if the id ever exceeds u32 — confirm ServerId's range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection creates its own receiver via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers (user side).
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests only a dev-server principal may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the `forward_*` helpers (defined elsewhere).
            // NOTE(review): the helper names suggest owner-only / read-only / mutating
            // access checks — confirm against their definitions.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and messages broadcast from a project's host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model endpoints. API keys are read from the server config at
            // call time. The streaming handler is not wrapped in `user_handler`, so it
            // receives the raw `Session`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
651
    /// Schedules cleanup of resources left behind by stale servers, then
    /// returns `Ok(())` without waiting for it.
    ///
    /// A detached task waits [`CLEANUP_TIMEOUT`] and then:
    /// 1. asks the database for rooms and channel buffers still tied to stale
    ///    servers;
    /// 2. clears their stale collaborators/participants and notifies the
    ///    connections that remain;
    /// 3. deletes the stale server records.
    ///
    /// Each failing step is logged (`trace_err`) and skipped rather than
    /// aborting the cleanup.
    ///
    /// NOTE(review): which servers count as "stale" is decided by the DB
    /// queries (`stale_server_resource_ids`, `delete_stale_servers`), not here;
    /// presumably every server in this environment other than `server_id`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then send the refreshed
                    // collaborator list to every connection still on the buffer (best-effort).
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, broadcast the new room/channel state,
                    // send `CallCanceled` for canceled calls, refresh affected users' contact
                    // status, and delete the LiveKit room if nobody is left.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // get their contact status refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Query the DB first; the synchronous pool lock is taken only
                            // afterwards, so it is never held across an `.await`.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                // Push the updated entry to each accepted contact's connections.
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if retrieving the stale resources failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
797
798 pub fn teardown(&self) {
799 self.peer.teardown();
800 self.connection_pool.lock().reset();
801 let _ = self.teardown.send(true);
802 }
803
804 #[cfg(test)]
805 pub fn reset(&self, id: ServerId) {
806 self.teardown();
807 *self.id.lock() = id;
808 self.peer.reset(id.0 as u32);
809 let _ = self.teardown.send(false);
810 }
811
812 #[cfg(test)]
813 pub fn id(&self) -> ServerId {
814 *self.id.lock()
815 }
816
817 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
818 where
819 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
820 Fut: 'static + Send + Future<Output = Result<()>>,
821 M: EnvelopedMessage,
822 {
823 let prev_handler = self.handlers.insert(
824 TypeId::of::<M>(),
825 Box::new(move |envelope, session| {
826 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
827 let received_at = envelope.received_at;
828 tracing::info!("message received");
829 let start_time = Instant::now();
830 let future = (handler)(*envelope, session);
831 async move {
832 let result = future.await;
833 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
834 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
835 let queue_duration_ms = total_duration_ms - processing_duration_ms;
836 let payload_type = M::NAME;
837
838 match result {
839 Err(error) => {
840 tracing::error!(
841 ?error,
842 total_duration_ms,
843 processing_duration_ms,
844 queue_duration_ms,
845 payload_type,
846 "error handling message"
847 )
848 }
849 Ok(()) => tracing::info!(
850 total_duration_ms,
851 processing_duration_ms,
852 queue_duration_ms,
853 "finished handling message"
854 ),
855 }
856 }
857 .boxed()
858 }),
859 );
860 if prev_handler.is_some() {
861 panic!("registered a handler for the same message twice");
862 }
863 self
864 }
865
866 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
869 Fut: 'static + Send + Future<Output = Result<()>>,
870 M: EnvelopedMessage,
871 {
872 self.add_handler(move |envelope, session| handler(envelope.payload, session));
873 self
874 }
875
876 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
879 Fut: Send + Future<Output = Result<()>>,
880 M: RequestMessage,
881 {
882 let handler = Arc::new(handler);
883 self.add_handler(move |envelope, session| {
884 let receipt = envelope.receipt();
885 let handler = handler.clone();
886 async move {
887 let peer = session.peer.clone();
888 let responded = Arc::new(AtomicBool::default());
889 let response = Response {
890 peer: peer.clone(),
891 responded: responded.clone(),
892 receipt,
893 };
894 match (handler)(envelope.payload, response, session).await {
895 Ok(()) => {
896 if responded.load(std::sync::atomic::Ordering::SeqCst) {
897 Ok(())
898 } else {
899 Err(anyhow!("handler did not send a response"))?
900 }
901 }
902 Err(error) => {
903 let proto_err = match &error {
904 Error::Internal(err) => err.to_proto(),
905 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
906 };
907 peer.respond_with_error(receipt, proto_err)?;
908 Err(error)
909 }
910 }
911 }
912 })
913 }
914
915 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
916 where
917 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
918 Fut: Send + Future<Output = Result<()>>,
919 M: RequestMessage,
920 {
921 let handler = Arc::new(handler);
922 self.add_handler(move |envelope, session| {
923 let receipt = envelope.receipt();
924 let handler = handler.clone();
925 async move {
926 let peer = session.peer.clone();
927 let response = StreamingResponse {
928 peer: peer.clone(),
929 receipt,
930 };
931 match (handler)(envelope.payload, response, session).await {
932 Ok(()) => {
933 peer.end_stream(receipt)?;
934 Ok(())
935 }
936 Err(error) => {
937 let proto_err = match &error {
938 Error::Internal(err) => err.to_proto(),
939 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
940 };
941 peer.respond_with_error(receipt, proto_err)?;
942 Err(error)
943 }
944 }
945 }
946 })
947 }
948
949 #[allow(clippy::too_many_arguments)]
950 pub fn handle_connection(
951 self: &Arc<Self>,
952 connection: Connection,
953 address: String,
954 principal: Principal,
955 zed_version: ZedVersion,
956 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
957 executor: Executor,
958 ) -> impl Future<Output = ()> {
959 let this = self.clone();
960 let span = info_span!("handle connection", %address,
961 connection_id=field::Empty,
962 user_id=field::Empty,
963 login=field::Empty,
964 impersonator=field::Empty,
965 dev_server_id=field::Empty
966 );
967 principal.update_span(&span);
968
969 let mut teardown = self.teardown.subscribe();
970 async move {
971 if *teardown.borrow() {
972 tracing::error!("server is tearing down");
973 return
974 }
975 let (connection_id, handle_io, mut incoming_rx) = this
976 .peer
977 .add_connection(connection, {
978 let executor = executor.clone();
979 move |duration| executor.sleep(duration)
980 });
981 tracing::Span::current().record("connection_id", format!("{}", connection_id));
982 tracing::info!("connection opened");
983
984 let http_client = match IsahcHttpClient::new() {
985 Ok(http_client) => Arc::new(http_client),
986 Err(error) => {
987 tracing::error!(?error, "failed to create HTTP client");
988 return;
989 }
990 };
991
992 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
993 Some(Arc::new(SupermavenAdminApi::new(
994 supermaven_admin_api_key.to_string(),
995 http_client.clone(),
996 )))
997 } else {
998 None
999 };
1000
1001 let session = Session {
1002 principal: principal.clone(),
1003 connection_id,
1004 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1005 peer: this.peer.clone(),
1006 connection_pool: this.connection_pool.clone(),
1007 live_kit_client: this.app_state.live_kit_client.clone(),
1008 http_client,
1009 rate_limiter: this.app_state.rate_limiter.clone(),
1010 _executor: executor.clone(),
1011 supermaven_client,
1012 };
1013
1014 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1015 tracing::error!(?error, "failed to send initial client update");
1016 return;
1017 }
1018
1019 let handle_io = handle_io.fuse();
1020 futures::pin_mut!(handle_io);
1021
1022 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1023 // This prevents deadlocks when e.g., client A performs a request to client B and
1024 // client B performs a request to client A. If both clients stop processing further
1025 // messages until their respective request completes, they won't have a chance to
1026 // respond to the other client's request and cause a deadlock.
1027 //
1028 // This arrangement ensures we will attempt to process earlier messages first, but fall
1029 // back to processing messages arrived later in the spirit of making progress.
1030 let mut foreground_message_handlers = FuturesUnordered::new();
1031 let concurrent_handlers = Arc::new(Semaphore::new(256));
1032 loop {
1033 let next_message = async {
1034 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1035 let message = incoming_rx.next().await;
1036 (permit, message)
1037 }.fuse();
1038 futures::pin_mut!(next_message);
1039 futures::select_biased! {
1040 _ = teardown.changed().fuse() => return,
1041 result = handle_io => {
1042 if let Err(error) = result {
1043 tracing::error!(?error, "error handling I/O");
1044 }
1045 break;
1046 }
1047 _ = foreground_message_handlers.next() => {}
1048 next_message = next_message => {
1049 let (permit, message) = next_message;
1050 if let Some(message) = message {
1051 let type_name = message.payload_type_name();
1052 // note: we copy all the fields from the parent span so we can query them in the logs.
1053 // (https://github.com/tokio-rs/tracing/issues/2670).
1054 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1055 user_id=field::Empty,
1056 login=field::Empty,
1057 impersonator=field::Empty,
1058 dev_server_id=field::Empty
1059 );
1060 principal.update_span(&span);
1061 let span_enter = span.enter();
1062 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1063 let is_background = message.is_background();
1064 let handle_message = (handler)(message, session.clone());
1065 drop(span_enter);
1066
1067 let handle_message = async move {
1068 handle_message.await;
1069 drop(permit);
1070 }.instrument(span);
1071 if is_background {
1072 executor.spawn_detached(handle_message);
1073 } else {
1074 foreground_message_handlers.push(handle_message);
1075 }
1076 } else {
1077 tracing::error!("no message handler");
1078 }
1079 } else {
1080 tracing::info!("connection closed");
1081 break;
1082 }
1083 }
1084 }
1085 }
1086
1087 drop(foreground_message_handlers);
1088 tracing::info!("signing out");
1089 if let Err(error) = connection_lost(session, teardown, executor).await {
1090 tracing::error!(?error, "error signing out");
1091 }
1092
1093 }.instrument(span)
1094 }
1095
    /// Performs the initial handshake for a freshly added connection.
    ///
    /// Always sends `Hello` (carrying this connection's peer id) first, then fires
    /// `send_connection_id` if one was supplied, and then pushes principal-specific state:
    ///
    /// * users (including impersonated ones): `ShowContacts` plus a database update on the
    ///   user's first connection, registration in the connection pool together with the
    ///   initial contacts update, channel subscriptions (only when
    ///   `should_auto_subscribe_to_channels` allows it for `zed_version`), the dev server
    ///   projects update, any pending incoming call, and finally `update_user_contacts`;
    /// * dev servers: eviction of a connection already registered for the same dev server,
    ///   registration of this connection, `DevServerInstructions` listing its projects, and a
    ///   dev server projects update for the owning user.
    ///
    /// # Errors
    ///
    /// Returns the first failing send or database call; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already have been dropped; that is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: show the contacts panel once and
                // persist that it has been shown.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Both lookups are independent, so they run concurrently.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The pool lock is held while registering the connection and while building
                // the initial contacts update from that same pool state. The scope ends before
                // the next `.await`.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept. An older one is told to shut
                    // down and is removed from the pool before this one is registered.
                    // NOTE(review): if the send to the stale connection fails, `?` aborts
                    // before the new connection is added to the pool.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // The dev server's owner is sent the refreshed dev server project status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1191
1192 pub async fn invite_code_redeemed(
1193 self: &Arc<Self>,
1194 inviter_id: UserId,
1195 invitee_id: UserId,
1196 ) -> Result<()> {
1197 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1198 if let Some(code) = &user.invite_code {
1199 let pool = self.connection_pool.lock();
1200 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1201 for connection_id in pool.user_connection_ids(inviter_id) {
1202 self.peer.send(
1203 connection_id,
1204 proto::UpdateContacts {
1205 contacts: vec![invitee_contact.clone()],
1206 ..Default::default()
1207 },
1208 )?;
1209 self.peer.send(
1210 connection_id,
1211 proto::UpdateInviteInfo {
1212 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1213 count: user.invite_count as u32,
1214 },
1215 )?;
1216 }
1217 }
1218 }
1219 Ok(())
1220 }
1221
1222 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1223 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1224 if let Some(invite_code) = &user.invite_code {
1225 let pool = self.connection_pool.lock();
1226 for connection_id in pool.user_connection_ids(user_id) {
1227 self.peer.send(
1228 connection_id,
1229 proto::UpdateInviteInfo {
1230 url: format!(
1231 "{}{}",
1232 self.app_state.config.invite_link_prefix, invite_code
1233 ),
1234 count: user.invite_count as u32,
1235 },
1236 )?;
1237 }
1238 }
1239 }
1240 Ok(())
1241 }
1242
1243 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1244 ServerSnapshot {
1245 connection_pool: ConnectionPoolGuard {
1246 guard: self.connection_pool.lock(),
1247 _not_send: PhantomData,
1248 },
1249 peer: &self.peer,
1250 }
1251 }
1252}
1253
1254impl<'a> Deref for ConnectionPoolGuard<'a> {
1255 type Target = ConnectionPool;
1256
1257 fn deref(&self) -> &Self::Target {
1258 &self.guard
1259 }
1260}
1261
1262impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1263 fn deref_mut(&mut self) -> &mut Self::Target {
1264 &mut self.guard
1265 }
1266}
1267
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds only, runs `check_invariants` on the pool whenever the guard is dropped.
    /// This happens before the inner lock guard (a field) is released. In non-test builds
    /// this is a no-op.
    // NOTE(review): presumably `check_invariants` panics on a violated invariant; confirm in
    // `connection_pool`.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1274
1275fn broadcast<F>(
1276 sender_id: Option<ConnectionId>,
1277 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1278 mut f: F,
1279) where
1280 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1281{
1282 for receiver_id in receiver_ids {
1283 if Some(receiver_id) != sender_id {
1284 if let Err(error) = f(receiver_id) {
1285 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1286 }
1287 }
1288 }
1289}
1290
1291pub struct ProtocolVersion(u32);
1292
1293impl Header for ProtocolVersion {
1294 fn name() -> &'static HeaderName {
1295 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1297 }
1298
1299 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300 where
1301 Self: Sized,
1302 I: Iterator<Item = &'i axum::http::HeaderValue>,
1303 {
1304 let version = values
1305 .next()
1306 .ok_or_else(axum::headers::Error::invalid)?
1307 .to_str()
1308 .map_err(|_| axum::headers::Error::invalid())?
1309 .parse()
1310 .map_err(|_| axum::headers::Error::invalid())?;
1311 Ok(Self(version))
1312 }
1313
1314 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315 values.extend([self.0.to_string().parse().unwrap()]);
1316 }
1317}
1318
1319pub struct AppVersionHeader(SemanticVersion);
1320impl Header for AppVersionHeader {
1321 fn name() -> &'static HeaderName {
1322 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1323 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1324 }
1325
1326 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1327 where
1328 Self: Sized,
1329 I: Iterator<Item = &'i axum::http::HeaderValue>,
1330 {
1331 let version = values
1332 .next()
1333 .ok_or_else(axum::headers::Error::invalid)?
1334 .to_str()
1335 .map_err(|_| axum::headers::Error::invalid())?
1336 .parse()
1337 .map_err(|_| axum::headers::Error::invalid())?;
1338 Ok(Self(version))
1339 }
1340
1341 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1342 values.extend([self.0.to_string().parse().unwrap()]);
1343 }
1344}
1345
1346pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1347 Router::new()
1348 .route("/rpc", get(handle_websocket_request))
1349 .layer(
1350 ServiceBuilder::new()
1351 .layer(Extension(server.app_state.clone()))
1352 .layer(middleware::from_fn(auth::validate_header)),
1353 )
1354 .route("/metrics", get(handle_metrics))
1355 .layer(Extension(server))
1356}
1357
1358pub async fn handle_websocket_request(
1359 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1360 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1361 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1362 Extension(server): Extension<Arc<Server>>,
1363 Extension(principal): Extension<Principal>,
1364 ws: WebSocketUpgrade,
1365) -> axum::response::Response {
1366 if protocol_version != rpc::PROTOCOL_VERSION {
1367 return (
1368 StatusCode::UPGRADE_REQUIRED,
1369 "client must be upgraded".to_string(),
1370 )
1371 .into_response();
1372 }
1373
1374 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1375 return (
1376 StatusCode::UPGRADE_REQUIRED,
1377 "no version header found".to_string(),
1378 )
1379 .into_response();
1380 };
1381
1382 if !version.can_collaborate() {
1383 return (
1384 StatusCode::UPGRADE_REQUIRED,
1385 "client must be upgraded".to_string(),
1386 )
1387 .into_response();
1388 }
1389
1390 let socket_address = socket_address.to_string();
1391 ws.on_upgrade(move |socket| {
1392 let socket = socket
1393 .map_ok(to_tungstenite_message)
1394 .err_into()
1395 .with(|message| async move { Ok(to_axum_message(message)) });
1396 let connection = Connection::new(Box::pin(socket));
1397 async move {
1398 server
1399 .handle_connection(
1400 connection,
1401 socket_address,
1402 principal,
1403 version,
1404 None,
1405 Executor::Production,
1406 )
1407 .await;
1408 }
1409 })
1410}
1411
1412pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1413 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1414 let connections_metric = CONNECTIONS_METRIC
1415 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1416
1417 let connections = server
1418 .connection_pool
1419 .lock()
1420 .connections()
1421 .filter(|connection| !connection.admin)
1422 .count();
1423 connections_metric.set(connections as _);
1424
1425 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1426 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1427 register_int_gauge!(
1428 "shared_projects",
1429 "number of open projects with one or more guests"
1430 )
1431 .unwrap()
1432 });
1433
1434 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1435 shared_projects_metric.set(shared_projects as _);
1436
1437 let encoder = prometheus::TextEncoder::new();
1438 let metric_families = prometheus::gather();
1439 let encoded_metrics = encoder
1440 .encode_to_string(&metric_families)
1441 .map_err(|err| anyhow!("{}", err))?;
1442 Ok(encoded_metrics)
1443}
1444
/// Cleans up after a connection has closed.
///
/// Immediate steps:
///
/// 1. Disconnects the peer.
/// 2. Removes the connection from the pool. A failure here is returned.
/// 3. Records the loss in the database. A failure here is only logged.
///
/// It then races `RECONNECT_TIMEOUT` against a change of the `teardown` watch. If
/// `teardown` changes first, the remaining cleanup is skipped. Once the timeout elapses:
///
/// * users (including impersonated ones) leave their room and channel buffers (errors are
///   logged). Any pending call is declined if the user has no other online connection.
///   Then `update_user_contacts` runs.
/// * dev servers are handed to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in order, so the timeout branch is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Cannot fail: the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // The pool guard is a temporary in the condition and is released before
                    // the database is locked below.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    // Cannot fail: the principal was just matched as a dev server.
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1499
1500/// Acknowledges a ping from a client, used to keep the connection alive.
1501async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1502 response.send(proto::Ack {})?;
1503 Ok(())
1504}
1505
1506/// Creates a new room for calling (outside of channels)
1507async fn create_room(
1508 _request: proto::CreateRoom,
1509 response: Response<proto::CreateRoom>,
1510 session: UserSession,
1511) -> Result<()> {
1512 let live_kit_room = nanoid::nanoid!(30);
1513
1514 let live_kit_connection_info = util::maybe!(async {
1515 let live_kit = session.live_kit_client.as_ref();
1516 let live_kit = live_kit?;
1517 let user_id = session.user_id().to_string();
1518
1519 let token = live_kit
1520 .room_token(&live_kit_room, &user_id.to_string())
1521 .trace_err()?;
1522
1523 Some(proto::LiveKitConnectionInfo {
1524 server_url: live_kit.url().into(),
1525 token,
1526 can_publish: true,
1527 })
1528 })
1529 .await;
1530
1531 let room = session
1532 .db()
1533 .await
1534 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1535 .await?;
1536
1537 response.send(proto::CreateRoomResponse {
1538 room: Some(room.clone()),
1539 live_kit_connection_info,
1540 })?;
1541
1542 update_user_contacts(session.user_id(), &session).await?;
1543 Ok(())
1544}
1545
1546/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1547async fn join_room(
1548 request: proto::JoinRoom,
1549 response: Response<proto::JoinRoom>,
1550 session: UserSession,
1551) -> Result<()> {
1552 let room_id = RoomId::from_proto(request.id);
1553
1554 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1555
1556 if let Some(channel_id) = channel_id {
1557 return join_channel_internal(channel_id, Box::new(response), session).await;
1558 }
1559
1560 let joined_room = {
1561 let room = session
1562 .db()
1563 .await
1564 .join_room(room_id, session.user_id(), session.connection_id)
1565 .await?;
1566 room_updated(&room.room, &session.peer);
1567 room.into_inner()
1568 };
1569
1570 for connection_id in session
1571 .connection_pool()
1572 .await
1573 .user_connection_ids(session.user_id())
1574 {
1575 session
1576 .peer
1577 .send(
1578 connection_id,
1579 proto::CallCanceled {
1580 room_id: room_id.to_proto(),
1581 },
1582 )
1583 .trace_err();
1584 }
1585
1586 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1587 if let Some(token) = live_kit
1588 .room_token(
1589 &joined_room.room.live_kit_room,
1590 &session.user_id().to_string(),
1591 )
1592 .trace_err()
1593 {
1594 Some(proto::LiveKitConnectionInfo {
1595 server_url: live_kit.url().into(),
1596 token,
1597 can_publish: true,
1598 })
1599 } else {
1600 None
1601 }
1602 } else {
1603 None
1604 };
1605
1606 response.send(proto::JoinRoomResponse {
1607 room: Some(joined_room.room),
1608 channel_id: None,
1609 live_kit_connection_info,
1610 })?;
1611
1612 update_user_contacts(session.user_id(), &session).await?;
1613 Ok(())
1614}
1615
/// Rejoins a room after connection errors (for example, after the client reconnects).
///
/// While holding the rejoined-room result from the database, this function:
/// 1. replies with the room, the reshared projects (id and collaborators), and the rejoined
///    projects;
/// 2. calls `room_updated`;
/// 3. for each reshared project, tells every collaborator that this session's peer id
///    changed from the project's `old_connection_id` to the current connection, and forwards
///    an `UpdateProject` with the project's worktrees to every collaborator except this
///    connection;
/// 4. streams the rejoined projects' state back to this connection via
///    `notify_rejoined_projects`.
///
/// After that scope ends, `channel_updated` is called if the room belongs to a channel, and
/// finally `update_user_contacts` runs.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they survive after the database result is dropped.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best effort: collaborator send failures are logged, not returned.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1707
1708fn notify_rejoined_projects(
1709 rejoined_projects: &mut Vec<RejoinedProject>,
1710 session: &UserSession,
1711) -> Result<()> {
1712 for project in rejoined_projects.iter() {
1713 for collaborator in &project.collaborators {
1714 session
1715 .peer
1716 .send(
1717 collaborator.connection_id,
1718 proto::UpdateProjectCollaborator {
1719 project_id: project.id.to_proto(),
1720 old_peer_id: Some(project.old_connection_id.into()),
1721 new_peer_id: Some(session.connection_id.into()),
1722 },
1723 )
1724 .trace_err();
1725 }
1726 }
1727
1728 for project in rejoined_projects {
1729 for worktree in mem::take(&mut project.worktrees) {
1730 #[cfg(any(test, feature = "test-support"))]
1731 const MAX_CHUNK_SIZE: usize = 2;
1732 #[cfg(not(any(test, feature = "test-support")))]
1733 const MAX_CHUNK_SIZE: usize = 256;
1734
1735 // Stream this worktree's entries.
1736 let message = proto::UpdateWorktree {
1737 project_id: project.id.to_proto(),
1738 worktree_id: worktree.id,
1739 abs_path: worktree.abs_path.clone(),
1740 root_name: worktree.root_name,
1741 updated_entries: worktree.updated_entries,
1742 removed_entries: worktree.removed_entries,
1743 scan_id: worktree.scan_id,
1744 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1745 updated_repositories: worktree.updated_repositories,
1746 removed_repositories: worktree.removed_repositories,
1747 };
1748 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1749 session.peer.send(session.connection_id, update.clone())?;
1750 }
1751
1752 // Stream this worktree's diagnostics.
1753 for summary in worktree.diagnostic_summaries {
1754 session.peer.send(
1755 session.connection_id,
1756 proto::UpdateDiagnosticSummary {
1757 project_id: project.id.to_proto(),
1758 worktree_id: worktree.id,
1759 summary: Some(summary),
1760 },
1761 )?;
1762 }
1763
1764 for settings_file in worktree.settings_files {
1765 session.peer.send(
1766 session.connection_id,
1767 proto::UpdateWorktreeSettings {
1768 project_id: project.id.to_proto(),
1769 worktree_id: worktree.id,
1770 path: settings_file.path,
1771 content: Some(settings_file.content),
1772 },
1773 )?;
1774 }
1775 }
1776
1777 for language_server in &project.language_servers {
1778 session.peer.send(
1779 session.connection_id,
1780 proto::UpdateLanguageServer {
1781 project_id: project.id.to_proto(),
1782 language_server_id: language_server.id,
1783 variant: Some(
1784 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1785 proto::LspDiskBasedDiagnosticsUpdated {},
1786 ),
1787 ),
1788 },
1789 )?;
1790 }
1791 }
1792 Ok(())
1793}
1794
1795/// leave room disconnects from the room.
1796async fn leave_room(
1797 _: proto::LeaveRoom,
1798 response: Response<proto::LeaveRoom>,
1799 session: UserSession,
1800) -> Result<()> {
1801 leave_room_for_session(&session, session.connection_id).await?;
1802 response.send(proto::Ack {})?;
1803 Ok(())
1804}
1805
1806/// Updates the permissions of someone else in the room.
1807async fn set_room_participant_role(
1808 request: proto::SetRoomParticipantRole,
1809 response: Response<proto::SetRoomParticipantRole>,
1810 session: UserSession,
1811) -> Result<()> {
1812 let user_id = UserId::from_proto(request.user_id);
1813 let role = ChannelRole::from(request.role());
1814
1815 let (live_kit_room, can_publish) = {
1816 let room = session
1817 .db()
1818 .await
1819 .set_room_participant_role(
1820 session.user_id(),
1821 RoomId::from_proto(request.room_id),
1822 user_id,
1823 role,
1824 )
1825 .await?;
1826
1827 let live_kit_room = room.live_kit_room.clone();
1828 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1829 room_updated(&room, &session.peer);
1830 (live_kit_room, can_publish)
1831 };
1832
1833 if let Some(live_kit) = session.live_kit_client.as_ref() {
1834 live_kit
1835 .update_participant(
1836 live_kit_room.clone(),
1837 request.user_id.to_string(),
1838 live_kit_server::proto::ParticipantPermission {
1839 can_subscribe: true,
1840 can_publish,
1841 can_publish_data: can_publish,
1842 hidden: false,
1843 recorder: false,
1844 },
1845 )
1846 .await
1847 .trace_err();
1848 }
1849
1850 response.send(proto::Ack {})?;
1851 Ok(())
1852}
1853
/// Calls someone else into the current room.
///
/// The called user must be a contact of the caller. The call is recorded in the database
/// (calling `room_updated`), `update_user_contacts` runs for the called user, and the
/// incoming call is sent as a request to every connection of the called user.
///
/// The first connection that answers successfully makes this request succeed with an `Ack`.
/// The remaining in-flight requests are dropped with `calls`. If every request fails, or the
/// user has no connections, the call is marked as failed in the database, `room_updated` and
/// `update_user_contacts` run again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        // Borrow the database result mutably so the incoming-call message can be moved out
        // with `mem::take` while the room is still available for `room_updated`.
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1922
1923/// Cancel an outgoing call.
1924async fn cancel_call(
1925 request: proto::CancelCall,
1926 response: Response<proto::CancelCall>,
1927 session: UserSession,
1928) -> Result<()> {
1929 let called_user_id = UserId::from_proto(request.called_user_id);
1930 let room_id = RoomId::from_proto(request.room_id);
1931 {
1932 let room = session
1933 .db()
1934 .await
1935 .cancel_call(room_id, session.connection_id, called_user_id)
1936 .await?;
1937 room_updated(&room, &session.peer);
1938 }
1939
1940 for connection_id in session
1941 .connection_pool()
1942 .await
1943 .user_connection_ids(called_user_id)
1944 {
1945 session
1946 .peer
1947 .send(
1948 connection_id,
1949 proto::CallCanceled {
1950 room_id: room_id.to_proto(),
1951 },
1952 )
1953 .trace_err();
1954 }
1955 response.send(proto::Ack {})?;
1956
1957 update_user_contacts(called_user_id, &session).await?;
1958 Ok(())
1959}
1960
1961/// Decline an incoming call.
1962async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1963 let room_id = RoomId::from_proto(message.room_id);
1964 {
1965 let room = session
1966 .db()
1967 .await
1968 .decline_call(Some(room_id), session.user_id())
1969 .await?
1970 .ok_or_else(|| anyhow!("failed to decline call"))?;
1971 room_updated(&room, &session.peer);
1972 }
1973
1974 for connection_id in session
1975 .connection_pool()
1976 .await
1977 .user_connection_ids(session.user_id())
1978 {
1979 session
1980 .peer
1981 .send(
1982 connection_id,
1983 proto::CallCanceled {
1984 room_id: room_id.to_proto(),
1985 },
1986 )
1987 .trace_err();
1988 }
1989 update_user_contacts(session.user_id(), &session).await?;
1990 Ok(())
1991}
1992
1993/// Updates other participants in the room with your current location.
1994async fn update_participant_location(
1995 request: proto::UpdateParticipantLocation,
1996 response: Response<proto::UpdateParticipantLocation>,
1997 session: UserSession,
1998) -> Result<()> {
1999 let room_id = RoomId::from_proto(request.room_id);
2000 let location = request
2001 .location
2002 .ok_or_else(|| anyhow!("invalid location"))?;
2003
2004 let db = session.db().await;
2005 let room = db
2006 .update_room_participant_location(room_id, session.connection_id, location)
2007 .await?;
2008
2009 room_updated(&room, &session.peer);
2010 response.send(proto::Ack {})?;
2011 Ok(())
2012}
2013
2014/// Share a project into the room.
2015async fn share_project(
2016 request: proto::ShareProject,
2017 response: Response<proto::ShareProject>,
2018 session: UserSession,
2019) -> Result<()> {
2020 let (project_id, room) = &*session
2021 .db()
2022 .await
2023 .share_project(
2024 RoomId::from_proto(request.room_id),
2025 session.connection_id,
2026 &request.worktrees,
2027 request
2028 .dev_server_project_id
2029 .map(|id| DevServerProjectId::from_proto(id)),
2030 )
2031 .await?;
2032 response.send(proto::ShareProjectResponse {
2033 project_id: project_id.to_proto(),
2034 })?;
2035 room_updated(&room, &session.peer);
2036
2037 Ok(())
2038}
2039
2040/// Unshare a project from the room.
2041async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2042 let project_id = ProjectId::from_proto(message.project_id);
2043 unshare_project_internal(
2044 project_id,
2045 session.connection_id,
2046 session.user_id(),
2047 &session,
2048 )
2049 .await
2050}
2051
/// Unshares `project_id` for `connection_id` (and `user_id`, when given).
///
/// While the database result is held, an `UnshareProject` message goes to every guest
/// connection except `connection_id`, and `room_updated` is called when a room is returned.
/// If the database reports that the project should be deleted, `delete_project` runs after
/// that result has been released.
///
/// # Errors
///
/// Fails if either database call fails. Send failures to guests are only logged by
/// `broadcast`.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so `room_guard` can be dropped at the end of this block.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2090
2091/// DevServer makes a project available online
2092async fn share_dev_server_project(
2093 request: proto::ShareDevServerProject,
2094 response: Response<proto::ShareDevServerProject>,
2095 session: DevServerSession,
2096) -> Result<()> {
2097 let (dev_server_project, user_id, status) = session
2098 .db()
2099 .await
2100 .share_dev_server_project(
2101 DevServerProjectId::from_proto(request.dev_server_project_id),
2102 session.dev_server_id(),
2103 session.connection_id,
2104 &request.worktrees,
2105 )
2106 .await?;
2107 let Some(project_id) = dev_server_project.project_id else {
2108 return Err(anyhow!("failed to share remote project"))?;
2109 };
2110
2111 send_dev_server_projects_update(user_id, status, &session).await;
2112
2113 response.send(proto::ShareProjectResponse { project_id })?;
2114
2115 Ok(())
2116}
2117
/// Joins someone else's shared project.
///
/// Registers this connection as a collaborator in the database, then delegates the
/// notifications and the response to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `&mut *` on the returned temporary extends its lifetime to the end of this function,
    // so `project` and `replica_id` stay valid after the database handle is dropped below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before building and sending the response.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2136
/// A response handle that can deliver a `JoinProjectResponse`.
///
/// This lets `join_project_internal` answer both `JoinProject` and `JoinHostedProject`
/// requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Type-qualified path: resolves to the inherent `Response::send`, not back into this
        // trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Type-qualified path: resolves to the inherent `Response::send`, not back into this
        // trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2150
2151fn join_project_internal(
2152 response: impl JoinProjectInternalResponse,
2153 session: UserSession,
2154 project: &mut Project,
2155 replica_id: &ReplicaId,
2156) -> Result<()> {
2157 let collaborators = project
2158 .collaborators
2159 .iter()
2160 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2161 .map(|collaborator| collaborator.to_proto())
2162 .collect::<Vec<_>>();
2163 let project_id = project.id;
2164 let guest_user_id = session.user_id();
2165
2166 let worktrees = project
2167 .worktrees
2168 .iter()
2169 .map(|(id, worktree)| proto::WorktreeMetadata {
2170 id: *id,
2171 root_name: worktree.root_name.clone(),
2172 visible: worktree.visible,
2173 abs_path: worktree.abs_path.clone(),
2174 })
2175 .collect::<Vec<_>>();
2176
2177 let add_project_collaborator = proto::AddProjectCollaborator {
2178 project_id: project_id.to_proto(),
2179 collaborator: Some(proto::Collaborator {
2180 peer_id: Some(session.connection_id.into()),
2181 replica_id: replica_id.0 as u32,
2182 user_id: guest_user_id.to_proto(),
2183 }),
2184 };
2185
2186 for collaborator in &collaborators {
2187 session
2188 .peer
2189 .send(
2190 collaborator.peer_id.unwrap().into(),
2191 add_project_collaborator.clone(),
2192 )
2193 .trace_err();
2194 }
2195
2196 // First, we send the metadata associated with each worktree.
2197 response.send(proto::JoinProjectResponse {
2198 project_id: project.id.0 as u64,
2199 worktrees: worktrees.clone(),
2200 replica_id: replica_id.0 as u32,
2201 collaborators: collaborators.clone(),
2202 language_servers: project.language_servers.clone(),
2203 role: project.role.into(),
2204 dev_server_project_id: project
2205 .dev_server_project_id
2206 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2207 })?;
2208
2209 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2210 #[cfg(any(test, feature = "test-support"))]
2211 const MAX_CHUNK_SIZE: usize = 2;
2212 #[cfg(not(any(test, feature = "test-support")))]
2213 const MAX_CHUNK_SIZE: usize = 256;
2214
2215 // Stream this worktree's entries.
2216 let message = proto::UpdateWorktree {
2217 project_id: project_id.to_proto(),
2218 worktree_id,
2219 abs_path: worktree.abs_path.clone(),
2220 root_name: worktree.root_name,
2221 updated_entries: worktree.entries,
2222 removed_entries: Default::default(),
2223 scan_id: worktree.scan_id,
2224 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2225 updated_repositories: worktree.repository_entries.into_values().collect(),
2226 removed_repositories: Default::default(),
2227 };
2228 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2229 session.peer.send(session.connection_id, update.clone())?;
2230 }
2231
2232 // Stream this worktree's diagnostics.
2233 for summary in worktree.diagnostic_summaries {
2234 session.peer.send(
2235 session.connection_id,
2236 proto::UpdateDiagnosticSummary {
2237 project_id: project_id.to_proto(),
2238 worktree_id: worktree.id,
2239 summary: Some(summary),
2240 },
2241 )?;
2242 }
2243
2244 for settings_file in worktree.settings_files {
2245 session.peer.send(
2246 session.connection_id,
2247 proto::UpdateWorktreeSettings {
2248 project_id: project_id.to_proto(),
2249 worktree_id: worktree.id,
2250 path: settings_file.path,
2251 content: Some(settings_file.content),
2252 },
2253 )?;
2254 }
2255 }
2256
2257 for language_server in &project.language_servers {
2258 session.peer.send(
2259 session.connection_id,
2260 proto::UpdateLanguageServer {
2261 project_id: project_id.to_proto(),
2262 language_server_id: language_server.id,
2263 variant: Some(
2264 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2265 proto::LspDiskBasedDiagnosticsUpdated {},
2266 ),
2267 ),
2268 },
2269 )?;
2270 }
2271
2272 Ok(())
2273}
2274
/// Leave someone else's shared project.
///
/// Hosted projects are left via `leave_hosted_project` and then reported with
/// `project_left`. For all other projects the departure is recorded with
/// `leave_project` and reported with `project_left`. If the project belongs to
/// a room, `room_updated` is then broadcast for that room.
///
/// Note: the database handle `db` stays locked for the entire function,
/// including the notification calls.
async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;
    if db.is_hosted_project(project_id).await? {
        let project = db.leave_hosted_project(project_id, sender_id).await?;
        project_left(&project, &session);
        return Ok(());
    }

    // `&*` borrows through the value returned by `leave_project` (presumably a
    // guard), keeping it alive until the end of the function.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(&project, &session);
    if let Some(room) = room {
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2299
2300async fn join_hosted_project(
2301 request: proto::JoinHostedProject,
2302 response: Response<proto::JoinHostedProject>,
2303 session: UserSession,
2304) -> Result<()> {
2305 let (mut project, replica_id) = session
2306 .db()
2307 .await
2308 .join_hosted_project(
2309 ProjectId(request.project_id as i32),
2310 session.user_id(),
2311 session.connection_id,
2312 )
2313 .await?;
2314
2315 join_project_internal(response, session, &mut project, &replica_id)
2316}
2317
2318async fn list_remote_directory(
2319 request: proto::ListRemoteDirectory,
2320 response: Response<proto::ListRemoteDirectory>,
2321 session: UserSession,
2322) -> Result<()> {
2323 let dev_server_id = DevServerId(request.dev_server_id as i32);
2324 let dev_server_connection_id = session
2325 .connection_pool()
2326 .await
2327 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2328
2329 session
2330 .db()
2331 .await
2332 .get_dev_server_for_user(dev_server_id, session.user_id())
2333 .await?;
2334
2335 response.send(
2336 session
2337 .peer
2338 .forward_request(session.connection_id, dev_server_connection_id, request)
2339 .await?,
2340 )?;
2341 Ok(())
2342}
2343
/// Changes the paths of a dev server project.
///
/// The steps are:
/// 1. Persist the new paths; the caller's user id is passed to the database
///    layer, which presumably enforces ownership.
/// 2. Send the dev server its refreshed project list (`DevServerInstructions`).
/// 3. Push updated dev-server status to the caller.
/// 4. Acknowledge the request.
///
/// NOTE(review): the database update happens before the dev server's
/// connection is looked up. If the dev server is offline, or does not support
/// `ZedVersion::with_list_directory()`, the `?` on that lookup returns an
/// error after the new paths were already saved, and
/// `send_dev_server_projects_update` is skipped. Confirm this is intended.
async fn update_dev_server_project(
    request: proto::UpdateDevServerProject,
    response: Response<proto::UpdateDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);

    let (dev_server_project, update) = session
        .db()
        .await
        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
        .await?;

    let projects = session
        .db()
        .await
        .get_projects_for_dev_server(dev_server_project.dev_server_id)
        .await?;

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id_supporting(
            dev_server_project.dev_server_id,
            ZedVersion::with_list_directory(),
        )?;

    session.peer.send(
        dev_server_connection_id,
        proto::DevServerInstructions { projects },
    )?;

    send_dev_server_projects_update(session.user_id(), update, &session).await;

    response.send(proto::Ack {})
}
2380
2381async fn create_dev_server_project(
2382 request: proto::CreateDevServerProject,
2383 response: Response<proto::CreateDevServerProject>,
2384 session: UserSession,
2385) -> Result<()> {
2386 let dev_server_id = DevServerId(request.dev_server_id as i32);
2387 let dev_server_connection_id = session
2388 .connection_pool()
2389 .await
2390 .dev_server_connection_id(dev_server_id);
2391 let Some(dev_server_connection_id) = dev_server_connection_id else {
2392 Err(ErrorCode::DevServerOffline
2393 .message("Cannot create a remote project when the dev server is offline".to_string())
2394 .anyhow())?
2395 };
2396
2397 let path = request.path.clone();
2398 //Check that the path exists on the dev server
2399 session
2400 .peer
2401 .forward_request(
2402 session.connection_id,
2403 dev_server_connection_id,
2404 proto::ValidateDevServerProjectRequest { path: path.clone() },
2405 )
2406 .await?;
2407
2408 let (dev_server_project, update) = session
2409 .db()
2410 .await
2411 .create_dev_server_project(
2412 DevServerId(request.dev_server_id as i32),
2413 &request.path,
2414 session.user_id(),
2415 )
2416 .await?;
2417
2418 let projects = session
2419 .db()
2420 .await
2421 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2422 .await?;
2423
2424 session.peer.send(
2425 dev_server_connection_id,
2426 proto::DevServerInstructions { projects },
2427 )?;
2428
2429 send_dev_server_projects_update(session.user_id(), update, &session).await;
2430
2431 response.send(proto::CreateDevServerProjectResponse {
2432 dev_server_project: Some(dev_server_project.to_proto(None)),
2433 })?;
2434 Ok(())
2435}
2436
2437async fn create_dev_server(
2438 request: proto::CreateDevServer,
2439 response: Response<proto::CreateDevServer>,
2440 session: UserSession,
2441) -> Result<()> {
2442 let access_token = auth::random_token();
2443 let hashed_access_token = auth::hash_access_token(&access_token);
2444
2445 if request.name.is_empty() {
2446 return Err(proto::ErrorCode::Forbidden
2447 .message("Dev server name cannot be empty".to_string())
2448 .anyhow())?;
2449 }
2450
2451 let (dev_server, status) = session
2452 .db()
2453 .await
2454 .create_dev_server(
2455 &request.name,
2456 request.ssh_connection_string.as_deref(),
2457 &hashed_access_token,
2458 session.user_id(),
2459 )
2460 .await?;
2461
2462 send_dev_server_projects_update(session.user_id(), status, &session).await;
2463
2464 response.send(proto::CreateDevServerResponse {
2465 dev_server_id: dev_server.id.0 as u64,
2466 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2467 name: request.name,
2468 })?;
2469 Ok(())
2470}
2471
2472async fn regenerate_dev_server_token(
2473 request: proto::RegenerateDevServerToken,
2474 response: Response<proto::RegenerateDevServerToken>,
2475 session: UserSession,
2476) -> Result<()> {
2477 let dev_server_id = DevServerId(request.dev_server_id as i32);
2478 let access_token = auth::random_token();
2479 let hashed_access_token = auth::hash_access_token(&access_token);
2480
2481 let connection_id = session
2482 .connection_pool()
2483 .await
2484 .dev_server_connection_id(dev_server_id);
2485 if let Some(connection_id) = connection_id {
2486 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2487 session.peer.send(
2488 connection_id,
2489 proto::ShutdownDevServer {
2490 reason: Some("dev server token was regenerated".to_string()),
2491 },
2492 )?;
2493 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2494 }
2495
2496 let status = session
2497 .db()
2498 .await
2499 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2500 .await?;
2501
2502 send_dev_server_projects_update(session.user_id(), status, &session).await;
2503
2504 response.send(proto::RegenerateDevServerTokenResponse {
2505 dev_server_id: dev_server_id.to_proto(),
2506 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2507 })?;
2508 Ok(())
2509}
2510
2511async fn rename_dev_server(
2512 request: proto::RenameDevServer,
2513 response: Response<proto::RenameDevServer>,
2514 session: UserSession,
2515) -> Result<()> {
2516 if request.name.trim().is_empty() {
2517 return Err(proto::ErrorCode::Forbidden
2518 .message("Dev server name cannot be empty".to_string())
2519 .anyhow())?;
2520 }
2521
2522 let dev_server_id = DevServerId(request.dev_server_id as i32);
2523 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2524 if dev_server.user_id != session.user_id() {
2525 return Err(anyhow!(ErrorCode::Forbidden))?;
2526 }
2527
2528 let status = session
2529 .db()
2530 .await
2531 .rename_dev_server(
2532 dev_server_id,
2533 &request.name,
2534 request.ssh_connection_string.as_deref(),
2535 session.user_id(),
2536 )
2537 .await?;
2538
2539 send_dev_server_projects_update(session.user_id(), status, &session).await;
2540
2541 response.send(proto::Ack {})?;
2542 Ok(())
2543}
2544
2545async fn delete_dev_server(
2546 request: proto::DeleteDevServer,
2547 response: Response<proto::DeleteDevServer>,
2548 session: UserSession,
2549) -> Result<()> {
2550 let dev_server_id = DevServerId(request.dev_server_id as i32);
2551 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2552 if dev_server.user_id != session.user_id() {
2553 return Err(anyhow!(ErrorCode::Forbidden))?;
2554 }
2555
2556 let connection_id = session
2557 .connection_pool()
2558 .await
2559 .dev_server_connection_id(dev_server_id);
2560 if let Some(connection_id) = connection_id {
2561 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2562 session.peer.send(
2563 connection_id,
2564 proto::ShutdownDevServer {
2565 reason: Some("dev server was deleted".to_string()),
2566 },
2567 )?;
2568 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2569 }
2570
2571 let status = session
2572 .db()
2573 .await
2574 .delete_dev_server(dev_server_id, session.user_id())
2575 .await?;
2576
2577 send_dev_server_projects_update(session.user_id(), status, &session).await;
2578
2579 response.send(proto::Ack {})?;
2580 Ok(())
2581}
2582
/// Deletes a project from one of the caller's dev servers.
///
/// Steps, in order:
/// 1. Load the project and its dev server. Reject with `Forbidden` unless the
///    caller owns that dev server.
/// 2. If the dev server is connected and `find_dev_server_project` finds a
///    shared project for it, unshare that project. A failed lookup is ignored,
///    presumably meaning the project is not currently shared.
/// 3. Delete the project. If the dev server is connected, send it the
///    remaining project list (`DevServerInstructions`).
/// 4. Push refreshed status to the caller and acknowledge.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    // Looked up once and reused for both the unshare and the final
    // instructions message.
    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // Deliberately not `?`: only unshare when a project was found.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2644
2645async fn rejoin_dev_server_projects(
2646 request: proto::RejoinRemoteProjects,
2647 response: Response<proto::RejoinRemoteProjects>,
2648 session: UserSession,
2649) -> Result<()> {
2650 let mut rejoined_projects = {
2651 let db = session.db().await;
2652 db.rejoin_dev_server_projects(
2653 &request.rejoined_projects,
2654 session.user_id(),
2655 session.0.connection_id,
2656 )
2657 .await?
2658 };
2659 response.send(proto::RejoinRemoteProjectsResponse {
2660 rejoined_projects: rejoined_projects
2661 .iter()
2662 .map(|project| project.to_proto())
2663 .collect(),
2664 })?;
2665 notify_rejoined_projects(&mut rejoined_projects, &session)
2666}
2667
/// Restores a dev server's shared projects after it reconnects.
///
/// `reshare_dev_server_projects` records the projects against the dev
/// server's new connection. For each reshared project:
/// - every collaborator is told the host's peer id changed
///   (`UpdateProjectCollaborator`; failures only logged);
/// - every collaborator is forwarded the project's current worktree list
///   (`UpdateProject`).
///
/// The dev server then receives each reshared project's collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // Scoped so the database handle is released before notifying peers.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2733
2734async fn shutdown_dev_server(
2735 _: proto::ShutdownDevServer,
2736 response: Response<proto::ShutdownDevServer>,
2737 session: DevServerSession,
2738) -> Result<()> {
2739 response.send(proto::Ack {})?;
2740 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2741 remove_dev_server_connection(session.dev_server_id(), &session).await
2742}
2743
2744async fn shutdown_dev_server_internal(
2745 dev_server_id: DevServerId,
2746 connection_id: ConnectionId,
2747 session: &Session,
2748) -> Result<()> {
2749 let (dev_server_projects, dev_server) = {
2750 let db = session.db().await;
2751 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2752 let dev_server = db.get_dev_server(dev_server_id).await?;
2753 (dev_server_projects, dev_server)
2754 };
2755
2756 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2757 unshare_project_internal(
2758 ProjectId::from_proto(project_id),
2759 connection_id,
2760 None,
2761 session,
2762 )
2763 .await?;
2764 }
2765
2766 session
2767 .connection_pool()
2768 .await
2769 .set_dev_server_offline(dev_server_id);
2770
2771 let status = session
2772 .db()
2773 .await
2774 .dev_server_projects_update(dev_server.user_id)
2775 .await?;
2776 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2777
2778 Ok(())
2779}
2780
2781async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2782 let dev_server_connection = session
2783 .connection_pool()
2784 .await
2785 .dev_server_connection_id(dev_server_id);
2786
2787 if let Some(dev_server_connection) = dev_server_connection {
2788 session
2789 .connection_pool()
2790 .await
2791 .remove_connection(dev_server_connection)?;
2792 }
2793 Ok(())
2794}
2795
/// Updates other participants with changes to the project.
///
/// Stores the new worktree list via the database's `update_project`. The
/// original request is then forwarded to every guest connection it returns.
/// If the project belongs to a room, that room is re-broadcast. Finally the
/// sender is acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // `&*` borrows through the value returned by the database (presumably a
    // guard), keeping it alive for the rest of the function.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(&room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2824
2825/// Updates other participants with changes to the worktree
2826async fn update_worktree(
2827 request: proto::UpdateWorktree,
2828 response: Response<proto::UpdateWorktree>,
2829 session: Session,
2830) -> Result<()> {
2831 let guest_connection_ids = session
2832 .db()
2833 .await
2834 .update_worktree(&request, session.connection_id)
2835 .await?;
2836
2837 broadcast(
2838 Some(session.connection_id),
2839 guest_connection_ids.iter().copied(),
2840 |connection_id| {
2841 session
2842 .peer
2843 .forward_send(session.connection_id, connection_id, request.clone())
2844 },
2845 );
2846 response.send(proto::Ack {})?;
2847 Ok(())
2848}
2849
2850/// Updates other participants with changes to the diagnostics
2851async fn update_diagnostic_summary(
2852 message: proto::UpdateDiagnosticSummary,
2853 session: Session,
2854) -> Result<()> {
2855 let guest_connection_ids = session
2856 .db()
2857 .await
2858 .update_diagnostic_summary(&message, session.connection_id)
2859 .await?;
2860
2861 broadcast(
2862 Some(session.connection_id),
2863 guest_connection_ids.iter().copied(),
2864 |connection_id| {
2865 session
2866 .peer
2867 .forward_send(session.connection_id, connection_id, message.clone())
2868 },
2869 );
2870
2871 Ok(())
2872}
2873
2874/// Updates other participants with changes to the worktree settings
2875async fn update_worktree_settings(
2876 message: proto::UpdateWorktreeSettings,
2877 session: Session,
2878) -> Result<()> {
2879 let guest_connection_ids = session
2880 .db()
2881 .await
2882 .update_worktree_settings(&message, session.connection_id)
2883 .await?;
2884
2885 broadcast(
2886 Some(session.connection_id),
2887 guest_connection_ids.iter().copied(),
2888 |connection_id| {
2889 session
2890 .peer
2891 .forward_send(session.connection_id, connection_id, message.clone())
2892 },
2893 );
2894
2895 Ok(())
2896}
2897
2898/// Notify other participants that a language server has started.
2899async fn start_language_server(
2900 request: proto::StartLanguageServer,
2901 session: Session,
2902) -> Result<()> {
2903 let guest_connection_ids = session
2904 .db()
2905 .await
2906 .start_language_server(&request, session.connection_id)
2907 .await?;
2908
2909 broadcast(
2910 Some(session.connection_id),
2911 guest_connection_ids.iter().copied(),
2912 |connection_id| {
2913 session
2914 .peer
2915 .forward_send(session.connection_id, connection_id, request.clone())
2916 },
2917 );
2918 Ok(())
2919}
2920
2921/// Notify other participants that a language server has changed.
2922async fn update_language_server(
2923 request: proto::UpdateLanguageServer,
2924 session: Session,
2925) -> Result<()> {
2926 let project_id = ProjectId::from_proto(request.project_id);
2927 let project_connection_ids = session
2928 .db()
2929 .await
2930 .project_connection_ids(project_id, session.connection_id, true)
2931 .await?;
2932 broadcast(
2933 Some(session.connection_id),
2934 project_connection_ids.iter().copied(),
2935 |connection_id| {
2936 session
2937 .peer
2938 .forward_send(session.connection_id, connection_id, request.clone())
2939 },
2940 );
2941 Ok(())
2942}
2943
2944/// forward a project request to the host. These requests should be read only
2945/// as guests are allowed to send them.
2946async fn forward_read_only_project_request<T>(
2947 request: T,
2948 response: Response<T>,
2949 session: UserSession,
2950) -> Result<()>
2951where
2952 T: EntityMessage + RequestMessage,
2953{
2954 let project_id = ProjectId::from_proto(request.remote_entity_id());
2955 let host_connection_id = session
2956 .db()
2957 .await
2958 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2959 .await?;
2960 let payload = session
2961 .peer
2962 .forward_request(session.connection_id, host_connection_id, request)
2963 .await?;
2964 response.send(payload)?;
2965 Ok(())
2966}
2967
2968/// forward a project request to the dev server. Only allowed
2969/// if it's your dev server.
2970async fn forward_project_request_for_owner<T>(
2971 request: T,
2972 response: Response<T>,
2973 session: UserSession,
2974) -> Result<()>
2975where
2976 T: EntityMessage + RequestMessage,
2977{
2978 let project_id = ProjectId::from_proto(request.remote_entity_id());
2979
2980 let host_connection_id = session
2981 .db()
2982 .await
2983 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2984 .await?;
2985 let payload = session
2986 .peer
2987 .forward_request(session.connection_id, host_connection_id, request)
2988 .await?;
2989 response.send(payload)?;
2990 Ok(())
2991}
2992
2993/// forward a project request to the host. These requests are disallowed
2994/// for guests.
2995async fn forward_mutating_project_request<T>(
2996 request: T,
2997 response: Response<T>,
2998 session: UserSession,
2999) -> Result<()>
3000where
3001 T: EntityMessage + RequestMessage,
3002{
3003 let project_id = ProjectId::from_proto(request.remote_entity_id());
3004
3005 let host_connection_id = session
3006 .db()
3007 .await
3008 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3009 .await?;
3010 let payload = session
3011 .peer
3012 .forward_request(session.connection_id, host_connection_id, request)
3013 .await?;
3014 response.send(payload)?;
3015 Ok(())
3016}
3017
3018/// forward a project request to the host. These requests are disallowed
3019/// for guests.
3020async fn forward_versioned_mutating_project_request<T>(
3021 request: T,
3022 response: Response<T>,
3023 session: UserSession,
3024) -> Result<()>
3025where
3026 T: EntityMessage + RequestMessage + VersionedMessage,
3027{
3028 let project_id = ProjectId::from_proto(request.remote_entity_id());
3029
3030 let host_connection_id = session
3031 .db()
3032 .await
3033 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3034 .await?;
3035 if let Some(host_version) = session
3036 .connection_pool()
3037 .await
3038 .connection(host_connection_id)
3039 .map(|c| c.zed_version)
3040 {
3041 if let Some(min_required_version) = request.required_host_version() {
3042 if min_required_version > host_version {
3043 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3044 .with_tag("required", &min_required_version.to_string())))?;
3045 }
3046 }
3047 }
3048
3049 let payload = session
3050 .peer
3051 .forward_request(session.connection_id, host_connection_id, request)
3052 .await?;
3053 response.send(payload)?;
3054 Ok(())
3055}
3056
3057/// Notify other participants that a new buffer has been created
3058async fn create_buffer_for_peer(
3059 request: proto::CreateBufferForPeer,
3060 session: Session,
3061) -> Result<()> {
3062 session
3063 .db()
3064 .await
3065 .check_user_is_project_host(
3066 ProjectId::from_proto(request.project_id),
3067 session.connection_id,
3068 )
3069 .await?;
3070 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3071 session
3072 .peer
3073 .forward_send(session.connection_id, peer_id.into(), request)?;
3074 Ok(())
3075}
3076
3077/// Notify other participants that a buffer has been updated. This is
3078/// allowed for guests as long as the update is limited to selections.
3079async fn update_buffer(
3080 request: proto::UpdateBuffer,
3081 response: Response<proto::UpdateBuffer>,
3082 session: Session,
3083) -> Result<()> {
3084 let project_id = ProjectId::from_proto(request.project_id);
3085 let mut capability = Capability::ReadOnly;
3086
3087 for op in request.operations.iter() {
3088 match op.variant {
3089 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3090 Some(_) => capability = Capability::ReadWrite,
3091 }
3092 }
3093
3094 let host = {
3095 let guard = session
3096 .db()
3097 .await
3098 .connections_for_buffer_update(
3099 project_id,
3100 session.principal_id(),
3101 session.connection_id,
3102 capability,
3103 )
3104 .await?;
3105
3106 let (host, guests) = &*guard;
3107
3108 broadcast(
3109 Some(session.connection_id),
3110 guests.clone(),
3111 |connection_id| {
3112 session
3113 .peer
3114 .forward_send(session.connection_id, connection_id, request.clone())
3115 },
3116 );
3117
3118 *host
3119 };
3120
3121 if host != session.connection_id {
3122 session
3123 .peer
3124 .forward_request(session.connection_id, host, request.clone())
3125 .await?;
3126 }
3127
3128 response.send(proto::Ack {})?;
3129 Ok(())
3130}
3131
3132async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3133 let project_id = ProjectId::from_proto(message.project_id);
3134
3135 let operation = message.operation.as_ref().context("invalid operation")?;
3136 let capability = match operation.variant.as_ref() {
3137 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3138 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3139 match buffer_op.variant {
3140 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3141 Capability::ReadOnly
3142 }
3143 _ => Capability::ReadWrite,
3144 }
3145 } else {
3146 Capability::ReadWrite
3147 }
3148 }
3149 Some(_) => Capability::ReadWrite,
3150 None => Capability::ReadOnly,
3151 };
3152
3153 let guard = session
3154 .db()
3155 .await
3156 .connections_for_buffer_update(
3157 project_id,
3158 session.principal_id(),
3159 session.connection_id,
3160 capability,
3161 )
3162 .await?;
3163
3164 let (host, guests) = &*guard;
3165
3166 broadcast(
3167 Some(session.connection_id),
3168 guests.iter().chain([host]).copied(),
3169 |connection_id| {
3170 session
3171 .peer
3172 .forward_send(session.connection_id, connection_id, message.clone())
3173 },
3174 );
3175
3176 Ok(())
3177}
3178
3179/// Notify other participants that a project has been updated.
3180async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3181 request: T,
3182 session: Session,
3183) -> Result<()> {
3184 let project_id = ProjectId::from_proto(request.remote_entity_id());
3185 let project_connection_ids = session
3186 .db()
3187 .await
3188 .project_connection_ids(project_id, session.connection_id, false)
3189 .await?;
3190
3191 broadcast(
3192 Some(session.connection_id),
3193 project_connection_ids.iter().copied(),
3194 |connection_id| {
3195 session
3196 .peer
3197 .forward_send(session.connection_id, connection_id, request.clone())
3198 },
3199 );
3200 Ok(())
3201}
3202
3203/// Start following another user in a call.
3204async fn follow(
3205 request: proto::Follow,
3206 response: Response<proto::Follow>,
3207 session: UserSession,
3208) -> Result<()> {
3209 let room_id = RoomId::from_proto(request.room_id);
3210 let project_id = request.project_id.map(ProjectId::from_proto);
3211 let leader_id = request
3212 .leader_id
3213 .ok_or_else(|| anyhow!("invalid leader id"))?
3214 .into();
3215 let follower_id = session.connection_id;
3216
3217 session
3218 .db()
3219 .await
3220 .check_room_participants(room_id, leader_id, session.connection_id)
3221 .await?;
3222
3223 let response_payload = session
3224 .peer
3225 .forward_request(session.connection_id, leader_id, request)
3226 .await?;
3227 response.send(response_payload)?;
3228
3229 if let Some(project_id) = project_id {
3230 let room = session
3231 .db()
3232 .await
3233 .follow(room_id, project_id, leader_id, follower_id)
3234 .await?;
3235 room_updated(&room, &session.peer);
3236 }
3237
3238 Ok(())
3239}
3240
3241/// Stop following another user in a call.
3242async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3243 let room_id = RoomId::from_proto(request.room_id);
3244 let project_id = request.project_id.map(ProjectId::from_proto);
3245 let leader_id = request
3246 .leader_id
3247 .ok_or_else(|| anyhow!("invalid leader id"))?
3248 .into();
3249 let follower_id = session.connection_id;
3250
3251 session
3252 .db()
3253 .await
3254 .check_room_participants(room_id, leader_id, session.connection_id)
3255 .await?;
3256
3257 session
3258 .peer
3259 .forward_send(session.connection_id, leader_id, request)?;
3260
3261 if let Some(project_id) = project_id {
3262 let room = session
3263 .db()
3264 .await
3265 .unfollow(room_id, project_id, leader_id, follower_id)
3266 .await?;
3267 room_updated(&room, &session.peer);
3268 }
3269
3270 Ok(())
3271}
3272
3273/// Notify everyone following you of your current location.
3274async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3275 let room_id = RoomId::from_proto(request.room_id);
3276 let database = session.db.lock().await;
3277
3278 let connection_ids = if let Some(project_id) = request.project_id {
3279 let project_id = ProjectId::from_proto(project_id);
3280 database
3281 .project_connection_ids(project_id, session.connection_id, true)
3282 .await?
3283 } else {
3284 database
3285 .room_connection_ids(room_id, session.connection_id)
3286 .await?
3287 };
3288
3289 // For now, don't send view update messages back to that view's current leader.
3290 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3291 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3292 _ => None,
3293 });
3294
3295 for connection_id in connection_ids.iter().cloned() {
3296 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3297 session
3298 .peer
3299 .forward_send(session.connection_id, connection_id, request.clone())?;
3300 }
3301 }
3302 Ok(())
3303}
3304
3305/// Get public data about users.
3306async fn get_users(
3307 request: proto::GetUsers,
3308 response: Response<proto::GetUsers>,
3309 session: Session,
3310) -> Result<()> {
3311 let user_ids = request
3312 .user_ids
3313 .into_iter()
3314 .map(UserId::from_proto)
3315 .collect();
3316 let users = session
3317 .db()
3318 .await
3319 .get_users_by_ids(user_ids)
3320 .await?
3321 .into_iter()
3322 .map(|user| proto::User {
3323 id: user.id.to_proto(),
3324 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3325 github_login: user.github_login,
3326 })
3327 .collect();
3328 response.send(proto::UsersResponse { users })?;
3329 Ok(())
3330}
3331
3332/// Search for users (to invite) buy Github login
3333async fn fuzzy_search_users(
3334 request: proto::FuzzySearchUsers,
3335 response: Response<proto::FuzzySearchUsers>,
3336 session: UserSession,
3337) -> Result<()> {
3338 let query = request.query;
3339 let users = match query.len() {
3340 0 => vec![],
3341 1 | 2 => session
3342 .db()
3343 .await
3344 .get_user_by_github_login(&query)
3345 .await?
3346 .into_iter()
3347 .collect(),
3348 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3349 };
3350 let users = users
3351 .into_iter()
3352 .filter(|user| user.id != session.user_id())
3353 .map(|user| proto::User {
3354 id: user.id.to_proto(),
3355 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3356 github_login: user.github_login,
3357 })
3358 .collect();
3359 response.send(proto::UsersResponse { users })?;
3360 Ok(())
3361}
3362
3363/// Send a contact request to another user.
3364async fn request_contact(
3365 request: proto::RequestContact,
3366 response: Response<proto::RequestContact>,
3367 session: UserSession,
3368) -> Result<()> {
3369 let requester_id = session.user_id();
3370 let responder_id = UserId::from_proto(request.responder_id);
3371 if requester_id == responder_id {
3372 return Err(anyhow!("cannot add yourself as a contact"))?;
3373 }
3374
3375 let notifications = session
3376 .db()
3377 .await
3378 .send_contact_request(requester_id, responder_id)
3379 .await?;
3380
3381 // Update outgoing contact requests of requester
3382 let mut update = proto::UpdateContacts::default();
3383 update.outgoing_requests.push(responder_id.to_proto());
3384 for connection_id in session
3385 .connection_pool()
3386 .await
3387 .user_connection_ids(requester_id)
3388 {
3389 session.peer.send(connection_id, update.clone())?;
3390 }
3391
3392 // Update incoming contact requests of responder
3393 let mut update = proto::UpdateContacts::default();
3394 update
3395 .incoming_requests
3396 .push(proto::IncomingContactRequest {
3397 requester_id: requester_id.to_proto(),
3398 });
3399 let connection_pool = session.connection_pool().await;
3400 for connection_id in connection_pool.user_connection_ids(responder_id) {
3401 session.peer.send(connection_id, update.clone())?;
3402 }
3403
3404 send_notifications(&connection_pool, &session.peer, notifications);
3405
3406 response.send(proto::Ack {})?;
3407 Ok(())
3408}
3409
/// Accept or decline a contact request
///
/// `request.response` selects the action:
/// - `Dismiss` only clears the responder's contact notification.
/// - `Accept` records the contact.
/// - Any other value is treated as a decline.
///
/// For accept and decline, both users' connections receive an
/// `UpdateContacts`. It removes the pending request and, on accept, adds the
/// other user as a contact. Any resulting notifications are then delivered.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // `db` stays alive until the end of the handler, including while the
    // connection pool is held below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Each side's busy state is embedded in the contact entry sent to the other.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3467
/// Remove a contact.
///
/// Deletes the relationship between the caller and `request.user_id`. The
/// database reports whether the contact had been accepted:
/// - If it had, both sides get `remove_contacts`.
/// - Otherwise the pending request is withdrawn from both sides
///   (`remove_outgoing_requests` for the caller, `remove_incoming_requests`
///   for the other user).
///
/// If a notification was deleted, the other user's connections are also told
/// to drop it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Only the other user's connections are told about the deleted notification.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3518
3519fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3520 version.0.minor() < 139
3521}
3522
3523async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3524 subscribe_user_to_channels(
3525 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3526 &session,
3527 )
3528 .await?;
3529 Ok(())
3530}
3531
/// Subscribe the session's connection to every channel `user_id` belongs to.
///
/// Each membership is registered, with its role, in the connection pool. Two
/// messages are then sent to this session's connection only: the user's
/// memberships (`build_update_user_channels`), followed by the channel data
/// (`build_channels_update`).
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // `pool` is held until the function returns, i.e. across the sends below.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    // Consumes `channels_for_user`, so it must come after the borrow above.
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
3548
3549/// Creates a new channel.
3550async fn create_channel(
3551 request: proto::CreateChannel,
3552 response: Response<proto::CreateChannel>,
3553 session: UserSession,
3554) -> Result<()> {
3555 let db = session.db().await;
3556
3557 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3558 let (channel, membership) = db
3559 .create_channel(&request.name, parent_id, session.user_id())
3560 .await?;
3561
3562 let root_id = channel.root_id();
3563 let channel = Channel::from_model(channel);
3564
3565 response.send(proto::CreateChannelResponse {
3566 channel: Some(channel.to_proto()),
3567 parent_id: request.parent_id,
3568 })?;
3569
3570 let mut connection_pool = session.connection_pool().await;
3571 if let Some(membership) = membership {
3572 connection_pool.subscribe_to_channel(
3573 membership.user_id,
3574 membership.channel_id,
3575 membership.role,
3576 );
3577 let update = proto::UpdateUserChannels {
3578 channel_memberships: vec![proto::ChannelMembership {
3579 channel_id: membership.channel_id.to_proto(),
3580 role: membership.role.into(),
3581 }],
3582 ..Default::default()
3583 };
3584 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3585 session.peer.send(connection_id, update.clone())?;
3586 }
3587 }
3588
3589 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3590 if !role.can_see_channel(channel.visibility) {
3591 continue;
3592 }
3593
3594 let update = proto::UpdateChannels {
3595 channels: vec![channel.to_proto()],
3596 ..Default::default()
3597 };
3598 session.peer.send(connection_id, update.clone())?;
3599 }
3600
3601 Ok(())
3602}
3603
/// Delete a channel
///
/// Deletes the channel on behalf of the caller and acknowledges the request.
/// It then sends `delete_channels`, listing every channel id the database
/// reports as removed, to all connections subscribed to the returned root
/// channel.
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    // Unlike the visibility-filtered updates elsewhere, deletions go to every
    // subscriber of the root regardless of role.
    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
3631
/// Invite someone to join a channel.
///
/// The handler works in these steps:
/// 1. Record the invitation for `request.user_id` with the requested role.
/// 2. Push the channel to every connection of the invitee as a
///    `channel_invitations` entry.
/// 3. Deliver any notifications the database produced.
/// 4. Acknowledge.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3668
3669/// remove someone from a channel
3670async fn remove_channel_member(
3671 request: proto::RemoveChannelMember,
3672 response: Response<proto::RemoveChannelMember>,
3673 session: UserSession,
3674) -> Result<()> {
3675 let db = session.db().await;
3676 let channel_id = ChannelId::from_proto(request.channel_id);
3677 let member_id = UserId::from_proto(request.user_id);
3678
3679 let RemoveChannelMemberResult {
3680 membership_update,
3681 notification_id,
3682 } = db
3683 .remove_channel_member(channel_id, member_id, session.user_id())
3684 .await?;
3685
3686 let mut connection_pool = session.connection_pool().await;
3687 notify_membership_updated(
3688 &mut connection_pool,
3689 membership_update,
3690 member_id,
3691 &session.peer,
3692 );
3693 for connection_id in connection_pool.user_connection_ids(member_id) {
3694 if let Some(notification_id) = notification_id {
3695 session
3696 .peer
3697 .send(
3698 connection_id,
3699 proto::DeleteNotification {
3700 notification_id: notification_id.to_proto(),
3701 },
3702 )
3703 .trace_err();
3704 }
3705 }
3706
3707 response.send(proto::Ack {})?;
3708 Ok(())
3709}
3710
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the change, every user subscribed to the channel's root is
/// re-evaluated against the new visibility:
/// - Users whose role can see the channel are (re)subscribed and sent it.
/// - Everyone else is unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool` (subscribe /
    // unsubscribe), which is presumably impossible while still iterating a
    // borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3757
/// Alter the role for a user in the channel.
///
/// The database reports one of two outcomes:
/// - `MembershipUpdated`: the change is applied to the connection pool and
///   announced via `notify_membership_updated`.
/// - `InviteUpdated`: presumably the user has only a pending invite (confirm).
///   The refreshed invitation is re-sent to the member's connections.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3805
/// Change the name of a channel
///
/// Renames the channel and replies with its new state. It then sends the
/// renamed channel to every connection subscribed to the channel's root whose
/// role can see it.
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3837
3838/// Move a channel to a new parent.
3839async fn move_channel(
3840 request: proto::MoveChannel,
3841 response: Response<proto::MoveChannel>,
3842 session: UserSession,
3843) -> Result<()> {
3844 let channel_id = ChannelId::from_proto(request.channel_id);
3845 let to = ChannelId::from_proto(request.to);
3846
3847 let (root_id, channels) = session
3848 .db()
3849 .await
3850 .move_channel(channel_id, to, session.user_id())
3851 .await?;
3852
3853 let connection_pool = session.connection_pool().await;
3854 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3855 let channels = channels
3856 .iter()
3857 .filter_map(|channel| {
3858 if role.can_see_channel(channel.visibility) {
3859 Some(channel.to_proto())
3860 } else {
3861 None
3862 }
3863 })
3864 .collect::<Vec<_>>();
3865 if channels.is_empty() {
3866 continue;
3867 }
3868
3869 let update = proto::UpdateChannels {
3870 channels,
3871 ..Default::default()
3872 };
3873
3874 session.peer.send(connection_id, update.clone())?;
3875 }
3876
3877 response.send(Ack {})?;
3878 Ok(())
3879}
3880
3881/// Get the list of channel members
3882async fn get_channel_members(
3883 request: proto::GetChannelMembers,
3884 response: Response<proto::GetChannelMembers>,
3885 session: UserSession,
3886) -> Result<()> {
3887 let db = session.db().await;
3888 let channel_id = ChannelId::from_proto(request.channel_id);
3889 let limit = if request.limit == 0 {
3890 u16::MAX as u64
3891 } else {
3892 request.limit
3893 };
3894 let (members, users) = db
3895 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3896 .await?;
3897 response.send(proto::GetChannelMembersResponse { members, users })?;
3898 Ok(())
3899}
3900
/// Accept or decline a channel invitation.
///
/// The database may return a membership update. If it does, it is applied to
/// the connection pool and announced via `notify_membership_updated`.
/// Otherwise the invitation is simply removed from the caller's connections.
/// Any notifications the database produced are delivered in either case.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        // No membership change: just clear the pending invitation.
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3941
3942/// Join the channels' room
3943async fn join_channel(
3944 request: proto::JoinChannel,
3945 response: Response<proto::JoinChannel>,
3946 session: UserSession,
3947) -> Result<()> {
3948 let channel_id = ChannelId::from_proto(request.channel_id);
3949 join_channel_internal(channel_id, Box::new(response), session).await
3950}
3951
/// A response handle that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`. This trait lets `join_channel_internal` serve
/// either kind of request.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `send` through a type-qualified path. The intent is to
// reach `Response`'s own `send` (the one other handlers call as
// `response.send(..)`), not to recurse into this trait's same-named method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3965
/// Shared implementation of joining a channel's room.
///
/// The steps are:
/// 1. Evict the user's stale room connection, if any.
/// 2. Join the channel in the database.
/// 3. Mint a LiveKit token when a LiveKit client is configured.
/// 4. Reply through `response`.
/// 5. Apply any membership change.
/// 6. Broadcast the room, then the channel, and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes `db` and `connection_pool`, so both are released
    // before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the handle before leaving the old room.
            // `leave_room_for_session` presumably acquires it itself (confirm),
            // so holding it here could deadlock. Re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and may not publish. If minting a token
        // fails, the error goes through `trace_err`, this becomes `None`, and
        // the join proceeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to yield the channel; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4056
/// Start editing the channel notes
///
/// Opens the channel's buffer for this connection and replies with the open
/// state. The resulting collaborator list is then broadcast to the other
/// collaborators that have a peer id.
async fn join_channel_buffer(
    request: proto::JoinChannelBuffer,
    response: Response<proto::JoinChannelBuffer>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let open_response = db
        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
        .await?;

    // Keep a copy: `open_response` is moved into the reply.
    let collaborators = open_response.collaborators.clone();
    response.send(open_response)?;

    let update = UpdateChannelBufferCollaborators {
        channel_id: channel_id.to_proto(),
        collaborators: collaborators.clone(),
    };
    channel_buffer_updated(
        session.connection_id,
        // Collaborators without a peer id cannot be addressed and are skipped.
        collaborators
            .iter()
            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
        &update,
        &session.peer,
    );

    Ok(())
}
4088
/// Edit the channel notes
///
/// Applies `request.operations` to the channel's buffer. The operations are
/// relayed to the other collaborators. Connections subscribed to the channel
/// that are not collaborating only receive the buffer's new epoch/version
/// (presumably so clients can flag unseen changes).
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // `collaborators` is cloned because it is needed again below to exclude
    // collaborators from the version broadcast.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
4139
/// Rejoin the channel notes after a connection blip
///
/// Re-attaches this connection to the requested buffers. For each rejoined
/// buffer, the refreshed collaborator list is broadcast to its other
/// addressable collaborators. The reply contains all rejoined buffers.
async fn rejoin_channel_buffers(
    request: proto::RejoinChannelBuffers,
    response: Response<proto::RejoinChannelBuffers>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let buffers = db
        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
        .await?;

    for rejoined_buffer in &buffers {
        // Collaborators without a peer id cannot be addressed and are skipped.
        let collaborators_to_notify = rejoined_buffer
            .buffer
            .collaborators
            .iter()
            .filter_map(|c| Some(c.peer_id?.into()));
        channel_buffer_updated(
            session.connection_id,
            collaborators_to_notify,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: rejoined_buffer.buffer.channel_id,
                collaborators: rejoined_buffer.buffer.collaborators.clone(),
            },
            &session.peer,
        );
    }

    response.send(proto::RejoinChannelBuffersResponse {
        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
    })?;

    Ok(())
}
4174
/// Stop editing the channel notes
///
/// Detaches this connection from the channel's buffer and acknowledges the
/// request. The remaining collaborator list is then broadcast to the
/// connections returned by the database.
async fn leave_channel_buffer(
    request: proto::LeaveChannelBuffer,
    response: Response<proto::LeaveChannelBuffer>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let left_buffer = db
        .leave_channel_buffer(channel_id, session.connection_id)
        .await?;

    response.send(Ack {})?;

    channel_buffer_updated(
        session.connection_id,
        left_buffer.connections,
        &proto::UpdateChannelBufferCollaborators {
            channel_id: channel_id.to_proto(),
            collaborators: left_buffer.collaborators,
        },
        &session.peer,
    );

    Ok(())
}
4202
4203fn channel_buffer_updated<T: EnvelopedMessage>(
4204 sender_id: ConnectionId,
4205 collaborators: impl IntoIterator<Item = ConnectionId>,
4206 message: &T,
4207 peer: &Peer,
4208) {
4209 broadcast(Some(sender_id), collaborators, |peer_id| {
4210 peer.send(peer_id, message.clone())
4211 });
4212}
4213
4214fn send_notifications(
4215 connection_pool: &ConnectionPool,
4216 peer: &Peer,
4217 notifications: db::NotificationBatch,
4218) {
4219 for (user_id, notification) in notifications {
4220 for connection_id in connection_pool.user_connection_ids(user_id) {
4221 if let Err(error) = peer.send(
4222 connection_id,
4223 proto::AddNotification {
4224 notification: Some(notification.clone()),
4225 },
4226 ) {
4227 tracing::error!(
4228 "failed to send notification to {:?} {}",
4229 connection_id,
4230 error
4231 );
4232 }
4233 }
4234 }
4235}
4236
4237/// Send a message to the channel
4238async fn send_channel_message(
4239 request: proto::SendChannelMessage,
4240 response: Response<proto::SendChannelMessage>,
4241 session: UserSession,
4242) -> Result<()> {
4243 // Validate the message body.
4244 let body = request.body.trim().to_string();
4245 if body.len() > MAX_MESSAGE_LEN {
4246 return Err(anyhow!("message is too long"))?;
4247 }
4248 if body.is_empty() {
4249 return Err(anyhow!("message can't be blank"))?;
4250 }
4251
4252 // TODO: adjust mentions if body is trimmed
4253
4254 let timestamp = OffsetDateTime::now_utc();
4255 let nonce = request
4256 .nonce
4257 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4258
4259 let channel_id = ChannelId::from_proto(request.channel_id);
4260 let CreatedChannelMessage {
4261 message_id,
4262 participant_connection_ids,
4263 notifications,
4264 } = session
4265 .db()
4266 .await
4267 .create_channel_message(
4268 channel_id,
4269 session.user_id(),
4270 &body,
4271 &request.mentions,
4272 timestamp,
4273 nonce.clone().into(),
4274 match request.reply_to_message_id {
4275 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4276 None => None,
4277 },
4278 )
4279 .await?;
4280
4281 let message = proto::ChannelMessage {
4282 sender_id: session.user_id().to_proto(),
4283 id: message_id.to_proto(),
4284 body,
4285 mentions: request.mentions,
4286 timestamp: timestamp.unix_timestamp() as u64,
4287 nonce: Some(nonce),
4288 reply_to_message_id: request.reply_to_message_id,
4289 edited_at: None,
4290 };
4291 broadcast(
4292 Some(session.connection_id),
4293 participant_connection_ids.clone(),
4294 |connection| {
4295 session.peer.send(
4296 connection,
4297 proto::ChannelMessageSent {
4298 channel_id: channel_id.to_proto(),
4299 message: Some(message.clone()),
4300 },
4301 )
4302 },
4303 );
4304 response.send(proto::SendChannelMessageResponse {
4305 message: Some(message),
4306 })?;
4307
4308 let pool = &*session.connection_pool().await;
4309 let non_participants =
4310 pool.channel_connection_ids(channel_id)
4311 .filter_map(|(connection_id, _)| {
4312 if participant_connection_ids.contains(&connection_id) {
4313 None
4314 } else {
4315 Some(connection_id)
4316 }
4317 });
4318 broadcast(None, non_participants, |peer_id| {
4319 session.peer.send(
4320 peer_id,
4321 proto::UpdateChannels {
4322 latest_channel_message_ids: vec![proto::ChannelMessageId {
4323 channel_id: channel_id.to_proto(),
4324 message_id: message_id.to_proto(),
4325 }],
4326 ..Default::default()
4327 },
4328 )
4329 });
4330 send_notifications(pool, &session.peer, notifications);
4331
4332 Ok(())
4333}
4334
4335/// Delete a channel message
4336async fn remove_channel_message(
4337 request: proto::RemoveChannelMessage,
4338 response: Response<proto::RemoveChannelMessage>,
4339 session: UserSession,
4340) -> Result<()> {
4341 let channel_id = ChannelId::from_proto(request.channel_id);
4342 let message_id = MessageId::from_proto(request.message_id);
4343 let (connection_ids, existing_notification_ids) = session
4344 .db()
4345 .await
4346 .remove_channel_message(channel_id, message_id, session.user_id())
4347 .await?;
4348
4349 broadcast(
4350 Some(session.connection_id),
4351 connection_ids,
4352 move |connection| {
4353 session.peer.send(connection, request.clone())?;
4354
4355 for notification_id in &existing_notification_ids {
4356 session.peer.send(
4357 connection,
4358 proto::DeleteNotification {
4359 notification_id: (*notification_id).to_proto(),
4360 },
4361 )?;
4362 }
4363
4364 Ok(())
4365 },
4366 );
4367 response.send(proto::Ack {})?;
4368 Ok(())
4369}
4370
4371async fn update_channel_message(
4372 request: proto::UpdateChannelMessage,
4373 response: Response<proto::UpdateChannelMessage>,
4374 session: UserSession,
4375) -> Result<()> {
4376 let channel_id = ChannelId::from_proto(request.channel_id);
4377 let message_id = MessageId::from_proto(request.message_id);
4378 let updated_at = OffsetDateTime::now_utc();
4379 let UpdatedChannelMessage {
4380 message_id,
4381 participant_connection_ids,
4382 notifications,
4383 reply_to_message_id,
4384 timestamp,
4385 deleted_mention_notification_ids,
4386 updated_mention_notifications,
4387 } = session
4388 .db()
4389 .await
4390 .update_channel_message(
4391 channel_id,
4392 message_id,
4393 session.user_id(),
4394 request.body.as_str(),
4395 &request.mentions,
4396 updated_at,
4397 )
4398 .await?;
4399
4400 let nonce = request
4401 .nonce
4402 .clone()
4403 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4404
4405 let message = proto::ChannelMessage {
4406 sender_id: session.user_id().to_proto(),
4407 id: message_id.to_proto(),
4408 body: request.body.clone(),
4409 mentions: request.mentions.clone(),
4410 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4411 nonce: Some(nonce),
4412 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4413 edited_at: Some(updated_at.unix_timestamp() as u64),
4414 };
4415
4416 response.send(proto::Ack {})?;
4417
4418 let pool = &*session.connection_pool().await;
4419 broadcast(
4420 Some(session.connection_id),
4421 participant_connection_ids,
4422 |connection| {
4423 session.peer.send(
4424 connection,
4425 proto::ChannelMessageUpdate {
4426 channel_id: channel_id.to_proto(),
4427 message: Some(message.clone()),
4428 },
4429 )?;
4430
4431 for notification_id in &deleted_mention_notification_ids {
4432 session.peer.send(
4433 connection,
4434 proto::DeleteNotification {
4435 notification_id: (*notification_id).to_proto(),
4436 },
4437 )?;
4438 }
4439
4440 for notification in &updated_mention_notifications {
4441 session.peer.send(
4442 connection,
4443 proto::UpdateNotification {
4444 notification: Some(notification.clone()),
4445 },
4446 )?;
4447 }
4448
4449 Ok(())
4450 },
4451 );
4452
4453 send_notifications(pool, &session.peer, notifications);
4454
4455 Ok(())
4456}
4457
4458/// Mark a channel message as read
4459async fn acknowledge_channel_message(
4460 request: proto::AckChannelMessage,
4461 session: UserSession,
4462) -> Result<()> {
4463 let channel_id = ChannelId::from_proto(request.channel_id);
4464 let message_id = MessageId::from_proto(request.message_id);
4465 let notifications = session
4466 .db()
4467 .await
4468 .observe_channel_message(channel_id, session.user_id(), message_id)
4469 .await?;
4470 send_notifications(
4471 &*session.connection_pool().await,
4472 &session.peer,
4473 notifications,
4474 );
4475 Ok(())
4476}
4477
4478/// Mark a buffer version as synced
4479async fn acknowledge_buffer_version(
4480 request: proto::AckBufferOperation,
4481 session: UserSession,
4482) -> Result<()> {
4483 let buffer_id = BufferId::from_proto(request.buffer_id);
4484 session
4485 .db()
4486 .await
4487 .observe_buffer_version(
4488 buffer_id,
4489 session.user_id(),
4490 request.epoch as i32,
4491 &request.version,
4492 )
4493 .await?;
4494 Ok(())
4495}
4496
4497struct CompleteWithLanguageModelRateLimit;
4498
4499impl RateLimit for CompleteWithLanguageModelRateLimit {
4500 fn capacity() -> usize {
4501 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4502 .ok()
4503 .and_then(|v| v.parse().ok())
4504 .unwrap_or(120) // Picked arbitrarily
4505 }
4506
4507 fn refill_duration() -> chrono::Duration {
4508 chrono::Duration::hours(1)
4509 }
4510
4511 fn db_name() -> &'static str {
4512 "complete-with-language-model"
4513 }
4514}
4515
4516async fn complete_with_language_model(
4517 mut request: proto::CompleteWithLanguageModel,
4518 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4519 session: Session,
4520 open_ai_api_key: Option<Arc<str>>,
4521 google_ai_api_key: Option<Arc<str>>,
4522 anthropic_api_key: Option<Arc<str>>,
4523) -> Result<()> {
4524 let Some(session) = session.for_user() else {
4525 return Err(anyhow!("user not found"))?;
4526 };
4527 authorize_access_to_language_models(&session).await?;
4528 session
4529 .rate_limiter
4530 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4531 .await?;
4532
4533 let mut provider_and_model = request.model.split('/');
4534 let (provider, model) = match (
4535 provider_and_model.next().unwrap(),
4536 provider_and_model.next(),
4537 ) {
4538 (provider, Some(model)) => (provider, model),
4539 (model, None) => {
4540 if model.starts_with("gpt") {
4541 ("openai", model)
4542 } else if model.starts_with("gemini") {
4543 ("google", model)
4544 } else if model.starts_with("claude") {
4545 ("anthropic", model)
4546 } else {
4547 ("unknown", model)
4548 }
4549 }
4550 };
4551 let provider = provider.to_string();
4552 request.model = model.to_string();
4553
4554 match provider.as_str() {
4555 "openai" => {
4556 let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
4557 complete_with_open_ai(request, response, session, api_key).await?;
4558 }
4559 "anthropic" => {
4560 let api_key =
4561 anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
4562 complete_with_anthropic(request, response, session, api_key).await?;
4563 }
4564 "google" => {
4565 let api_key =
4566 google_ai_api_key.context("no Google AI API key configured on the server")?;
4567 complete_with_google_ai(request, response, session, api_key).await?;
4568 }
4569 provider => return Err(anyhow!("unknown provider {:?}", provider))?,
4570 }
4571
4572 Ok(())
4573}
4574
4575async fn complete_with_open_ai(
4576 request: proto::CompleteWithLanguageModel,
4577 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4578 session: UserSession,
4579 api_key: Arc<str>,
4580) -> Result<()> {
4581 let mut completion_stream = open_ai::stream_completion(
4582 session.http_client.as_ref(),
4583 OPEN_AI_API_URL,
4584 &api_key,
4585 crate::ai::language_model_request_to_open_ai(request)?,
4586 None,
4587 )
4588 .await
4589 .context("open_ai::stream_completion request failed within collab")?;
4590
4591 while let Some(event) = completion_stream.next().await {
4592 let event = event?;
4593 response.send(proto::LanguageModelResponse {
4594 choices: event
4595 .choices
4596 .into_iter()
4597 .map(|choice| proto::LanguageModelChoiceDelta {
4598 index: choice.index,
4599 delta: Some(proto::LanguageModelResponseMessage {
4600 role: choice.delta.role.map(|role| match role {
4601 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4602 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4603 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4604 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4605 } as i32),
4606 content: choice.delta.content,
4607 tool_calls: choice
4608 .delta
4609 .tool_calls
4610 .unwrap_or_default()
4611 .into_iter()
4612 .map(|delta| proto::ToolCallDelta {
4613 index: delta.index as u32,
4614 id: delta.id,
4615 variant: match delta.function {
4616 Some(function) => {
4617 let name = function.name;
4618 let arguments = function.arguments;
4619
4620 Some(proto::tool_call_delta::Variant::Function(
4621 proto::tool_call_delta::FunctionCallDelta {
4622 name,
4623 arguments,
4624 },
4625 ))
4626 }
4627 None => None,
4628 },
4629 })
4630 .collect(),
4631 }),
4632 finish_reason: choice.finish_reason,
4633 })
4634 .collect(),
4635 })?;
4636 }
4637
4638 Ok(())
4639}
4640
4641async fn complete_with_google_ai(
4642 request: proto::CompleteWithLanguageModel,
4643 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4644 session: UserSession,
4645 api_key: Arc<str>,
4646) -> Result<()> {
4647 let mut stream = google_ai::stream_generate_content(
4648 session.http_client.clone(),
4649 google_ai::API_URL,
4650 api_key.as_ref(),
4651 &request.model.clone(),
4652 crate::ai::language_model_request_to_google_ai(request)?,
4653 )
4654 .await
4655 .context("google_ai::stream_generate_content request failed")?;
4656
4657 while let Some(event) = stream.next().await {
4658 let event = event?;
4659 response.send(proto::LanguageModelResponse {
4660 choices: event
4661 .candidates
4662 .unwrap_or_default()
4663 .into_iter()
4664 .map(|candidate| proto::LanguageModelChoiceDelta {
4665 index: candidate.index as u32,
4666 delta: Some(proto::LanguageModelResponseMessage {
4667 role: Some(match candidate.content.role {
4668 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4669 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4670 } as i32),
4671 content: Some(
4672 candidate
4673 .content
4674 .parts
4675 .into_iter()
4676 .filter_map(|part| match part {
4677 google_ai::Part::TextPart(part) => Some(part.text),
4678 google_ai::Part::InlineDataPart(_) => None,
4679 })
4680 .collect(),
4681 ),
4682 // Tool calls are not supported for Google
4683 tool_calls: Vec::new(),
4684 }),
4685 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4686 })
4687 .collect(),
4688 })?;
4689 }
4690
4691 Ok(())
4692}
4693
4694async fn complete_with_anthropic(
4695 request: proto::CompleteWithLanguageModel,
4696 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4697 session: UserSession,
4698 api_key: Arc<str>,
4699) -> Result<()> {
4700 let model = anthropic::Model::from_id(&request.model)?;
4701
4702 let mut system_message = String::new();
4703 let messages = request
4704 .messages
4705 .into_iter()
4706 .filter_map(|message| {
4707 match message.role() {
4708 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4709 role: anthropic::Role::User,
4710 content: message.content,
4711 }),
4712 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4713 role: anthropic::Role::Assistant,
4714 content: message.content,
4715 }),
4716 // Anthropic's API breaks system instructions out as a separate field rather
4717 // than having a system message role.
4718 LanguageModelRole::LanguageModelSystem => {
4719 if !system_message.is_empty() {
4720 system_message.push_str("\n\n");
4721 }
4722 system_message.push_str(&message.content);
4723
4724 None
4725 }
4726 // We don't yet support tool calls for Anthropic
4727 LanguageModelRole::LanguageModelTool => None,
4728 }
4729 })
4730 .collect();
4731
4732 let mut stream = anthropic::stream_completion(
4733 session.http_client.as_ref(),
4734 anthropic::ANTHROPIC_API_URL,
4735 &api_key,
4736 anthropic::Request {
4737 model,
4738 messages,
4739 stream: true,
4740 system: system_message,
4741 max_tokens: 4092,
4742 },
4743 None,
4744 )
4745 .await?;
4746
4747 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4748
4749 while let Some(event) = stream.next().await {
4750 let event = event?;
4751
4752 match event {
4753 anthropic::ResponseEvent::MessageStart { message } => {
4754 if let Some(role) = message.role {
4755 if role == "assistant" {
4756 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4757 } else if role == "user" {
4758 current_role = proto::LanguageModelRole::LanguageModelUser;
4759 }
4760 }
4761 }
4762 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4763 match content_block {
4764 anthropic::ContentBlock::Text { text } => {
4765 if !text.is_empty() {
4766 response.send(proto::LanguageModelResponse {
4767 choices: vec![proto::LanguageModelChoiceDelta {
4768 index: 0,
4769 delta: Some(proto::LanguageModelResponseMessage {
4770 role: Some(current_role as i32),
4771 content: Some(text),
4772 tool_calls: Vec::new(),
4773 }),
4774 finish_reason: None,
4775 }],
4776 })?;
4777 }
4778 }
4779 }
4780 }
4781 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4782 anthropic::TextDelta::TextDelta { text } => {
4783 response.send(proto::LanguageModelResponse {
4784 choices: vec![proto::LanguageModelChoiceDelta {
4785 index: 0,
4786 delta: Some(proto::LanguageModelResponseMessage {
4787 role: Some(current_role as i32),
4788 content: Some(text),
4789 tool_calls: Vec::new(),
4790 }),
4791 finish_reason: None,
4792 }],
4793 })?;
4794 }
4795 },
4796 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4797 if let Some(stop_reason) = delta.stop_reason {
4798 response.send(proto::LanguageModelResponse {
4799 choices: vec![proto::LanguageModelChoiceDelta {
4800 index: 0,
4801 delta: None,
4802 finish_reason: Some(stop_reason),
4803 }],
4804 })?;
4805 }
4806 }
4807 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4808 anthropic::ResponseEvent::MessageStop {} => {}
4809 anthropic::ResponseEvent::Ping {} => {}
4810 }
4811 }
4812
4813 Ok(())
4814}
4815
4816struct CountTokensWithLanguageModelRateLimit;
4817
4818impl RateLimit for CountTokensWithLanguageModelRateLimit {
4819 fn capacity() -> usize {
4820 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4821 .ok()
4822 .and_then(|v| v.parse().ok())
4823 .unwrap_or(600) // Picked arbitrarily
4824 }
4825
4826 fn refill_duration() -> chrono::Duration {
4827 chrono::Duration::hours(1)
4828 }
4829
4830 fn db_name() -> &'static str {
4831 "count-tokens-with-language-model"
4832 }
4833}
4834
4835async fn count_tokens_with_language_model(
4836 request: proto::CountTokensWithLanguageModel,
4837 response: Response<proto::CountTokensWithLanguageModel>,
4838 session: UserSession,
4839 google_ai_api_key: Option<Arc<str>>,
4840) -> Result<()> {
4841 authorize_access_to_language_models(&session).await?;
4842
4843 if !request.model.starts_with("gemini") {
4844 return Err(anyhow!(
4845 "counting tokens for model: {:?} is not supported",
4846 request.model
4847 ))?;
4848 }
4849
4850 session
4851 .rate_limiter
4852 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4853 .await?;
4854
4855 let api_key = google_ai_api_key
4856 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4857 let tokens_response = google_ai::count_tokens(
4858 session.http_client.as_ref(),
4859 google_ai::API_URL,
4860 &api_key,
4861 crate::ai::count_tokens_request_to_google_ai(request)?,
4862 )
4863 .await?;
4864 response.send(proto::CountTokensResponse {
4865 token_count: tokens_response.total_tokens as u32,
4866 })?;
4867 Ok(())
4868}
4869
4870struct ComputeEmbeddingsRateLimit;
4871
4872impl RateLimit for ComputeEmbeddingsRateLimit {
4873 fn capacity() -> usize {
4874 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4875 .ok()
4876 .and_then(|v| v.parse().ok())
4877 .unwrap_or(5000) // Picked arbitrarily
4878 }
4879
4880 fn refill_duration() -> chrono::Duration {
4881 chrono::Duration::hours(1)
4882 }
4883
4884 fn db_name() -> &'static str {
4885 "compute-embeddings"
4886 }
4887}
4888
4889async fn compute_embeddings(
4890 request: proto::ComputeEmbeddings,
4891 response: Response<proto::ComputeEmbeddings>,
4892 session: UserSession,
4893 api_key: Option<Arc<str>>,
4894) -> Result<()> {
4895 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4896 authorize_access_to_language_models(&session).await?;
4897
4898 session
4899 .rate_limiter
4900 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4901 .await?;
4902
4903 let embeddings = match request.model.as_str() {
4904 "openai/text-embedding-3-small" => {
4905 open_ai::embed(
4906 session.http_client.as_ref(),
4907 OPEN_AI_API_URL,
4908 &api_key,
4909 OpenAiEmbeddingModel::TextEmbedding3Small,
4910 request.texts.iter().map(|text| text.as_str()),
4911 )
4912 .await?
4913 }
4914 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4915 };
4916
4917 let embeddings = request
4918 .texts
4919 .iter()
4920 .map(|text| {
4921 let mut hasher = sha2::Sha256::new();
4922 hasher.update(text.as_bytes());
4923 let result = hasher.finalize();
4924 result.to_vec()
4925 })
4926 .zip(
4927 embeddings
4928 .data
4929 .into_iter()
4930 .map(|embedding| embedding.embedding),
4931 )
4932 .collect::<HashMap<_, _>>();
4933
4934 let db = session.db().await;
4935 db.save_embeddings(&request.model, &embeddings)
4936 .await
4937 .context("failed to save embeddings")
4938 .trace_err();
4939
4940 response.send(proto::ComputeEmbeddingsResponse {
4941 embeddings: embeddings
4942 .into_iter()
4943 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4944 .collect(),
4945 })?;
4946 Ok(())
4947}
4948
4949async fn get_cached_embeddings(
4950 request: proto::GetCachedEmbeddings,
4951 response: Response<proto::GetCachedEmbeddings>,
4952 session: UserSession,
4953) -> Result<()> {
4954 authorize_access_to_language_models(&session).await?;
4955
4956 let db = session.db().await;
4957 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4958
4959 response.send(proto::GetCachedEmbeddingsResponse {
4960 embeddings: embeddings
4961 .into_iter()
4962 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4963 .collect(),
4964 })?;
4965 Ok(())
4966}
4967
4968async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4969 let db = session.db().await;
4970 let flags = db.get_user_flags(session.user_id()).await?;
4971 if flags.iter().any(|flag| flag == "language-models") {
4972 Ok(())
4973 } else {
4974 Err(anyhow!("permission denied"))?
4975 }
4976}
4977
4978/// Get a Supermaven API key for the user
4979async fn get_supermaven_api_key(
4980 _request: proto::GetSupermavenApiKey,
4981 response: Response<proto::GetSupermavenApiKey>,
4982 session: UserSession,
4983) -> Result<()> {
4984 let user_id: String = session.user_id().to_string();
4985 if !session.is_staff() {
4986 return Err(anyhow!("supermaven not enabled for this account"))?;
4987 }
4988
4989 let email = session
4990 .email()
4991 .ok_or_else(|| anyhow!("user must have an email"))?;
4992
4993 let supermaven_admin_api = session
4994 .supermaven_client
4995 .as_ref()
4996 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4997
4998 let result = supermaven_admin_api
4999 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
5000 .await?;
5001
5002 response.send(proto::GetSupermavenApiKeyResponse {
5003 api_key: result.api_key,
5004 })?;
5005
5006 Ok(())
5007}
5008
5009/// Start receiving chat updates for a channel
5010async fn join_channel_chat(
5011 request: proto::JoinChannelChat,
5012 response: Response<proto::JoinChannelChat>,
5013 session: UserSession,
5014) -> Result<()> {
5015 let channel_id = ChannelId::from_proto(request.channel_id);
5016
5017 let db = session.db().await;
5018 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5019 .await?;
5020 let messages = db
5021 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5022 .await?;
5023 response.send(proto::JoinChannelChatResponse {
5024 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5025 messages,
5026 })?;
5027 Ok(())
5028}
5029
5030/// Stop receiving chat updates for a channel
5031async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5032 let channel_id = ChannelId::from_proto(request.channel_id);
5033 session
5034 .db()
5035 .await
5036 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5037 .await?;
5038 Ok(())
5039}
5040
5041/// Retrieve the chat history for a channel
5042async fn get_channel_messages(
5043 request: proto::GetChannelMessages,
5044 response: Response<proto::GetChannelMessages>,
5045 session: UserSession,
5046) -> Result<()> {
5047 let channel_id = ChannelId::from_proto(request.channel_id);
5048 let messages = session
5049 .db()
5050 .await
5051 .get_channel_messages(
5052 channel_id,
5053 session.user_id(),
5054 MESSAGE_COUNT_PER_PAGE,
5055 Some(MessageId::from_proto(request.before_message_id)),
5056 )
5057 .await?;
5058 response.send(proto::GetChannelMessagesResponse {
5059 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5060 messages,
5061 })?;
5062 Ok(())
5063}
5064
5065/// Retrieve specific chat messages
5066async fn get_channel_messages_by_id(
5067 request: proto::GetChannelMessagesById,
5068 response: Response<proto::GetChannelMessagesById>,
5069 session: UserSession,
5070) -> Result<()> {
5071 let message_ids = request
5072 .message_ids
5073 .iter()
5074 .map(|id| MessageId::from_proto(*id))
5075 .collect::<Vec<_>>();
5076 let messages = session
5077 .db()
5078 .await
5079 .get_channel_messages_by_id(session.user_id(), &message_ids)
5080 .await?;
5081 response.send(proto::GetChannelMessagesResponse {
5082 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5083 messages,
5084 })?;
5085 Ok(())
5086}
5087
5088/// Retrieve the current users notifications
5089async fn get_notifications(
5090 request: proto::GetNotifications,
5091 response: Response<proto::GetNotifications>,
5092 session: UserSession,
5093) -> Result<()> {
5094 let notifications = session
5095 .db()
5096 .await
5097 .get_notifications(
5098 session.user_id(),
5099 NOTIFICATION_COUNT_PER_PAGE,
5100 request
5101 .before_id
5102 .map(|id| db::NotificationId::from_proto(id)),
5103 )
5104 .await?;
5105 response.send(proto::GetNotificationsResponse {
5106 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5107 notifications,
5108 })?;
5109 Ok(())
5110}
5111
5112/// Mark notifications as read
5113async fn mark_notification_as_read(
5114 request: proto::MarkNotificationRead,
5115 response: Response<proto::MarkNotificationRead>,
5116 session: UserSession,
5117) -> Result<()> {
5118 let database = &session.db().await;
5119 let notifications = database
5120 .mark_notification_as_read_by_id(
5121 session.user_id(),
5122 NotificationId::from_proto(request.notification_id),
5123 )
5124 .await?;
5125 send_notifications(
5126 &*session.connection_pool().await,
5127 &session.peer,
5128 notifications,
5129 );
5130 response.send(proto::Ack {})?;
5131 Ok(())
5132}
5133
5134/// Get the current users information
5135async fn get_private_user_info(
5136 _request: proto::GetPrivateUserInfo,
5137 response: Response<proto::GetPrivateUserInfo>,
5138 session: UserSession,
5139) -> Result<()> {
5140 let db = session.db().await;
5141
5142 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5143 let user = db
5144 .get_user_by_id(session.user_id())
5145 .await?
5146 .ok_or_else(|| anyhow!("user not found"))?;
5147 let flags = db.get_user_flags(session.user_id()).await?;
5148
5149 response.send(proto::GetPrivateUserInfoResponse {
5150 metrics_id,
5151 staff: user.admin,
5152 flags,
5153 })?;
5154 Ok(())
5155}
5156
5157fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5158 match message {
5159 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5160 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5161 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5162 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5163 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5164 code: frame.code.into(),
5165 reason: frame.reason,
5166 })),
5167 }
5168}
5169
5170fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5171 match message {
5172 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5173 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5174 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5175 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5176 AxumMessage::Close(frame) => {
5177 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5178 code: frame.code.into(),
5179 reason: frame.reason,
5180 }))
5181 }
5182 }
5183}
5184
5185fn notify_membership_updated(
5186 connection_pool: &mut ConnectionPool,
5187 result: MembershipUpdated,
5188 user_id: UserId,
5189 peer: &Peer,
5190) {
5191 for membership in &result.new_channels.channel_memberships {
5192 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5193 }
5194 for channel_id in &result.removed_channels {
5195 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5196 }
5197
5198 let user_channels_update = proto::UpdateUserChannels {
5199 channel_memberships: result
5200 .new_channels
5201 .channel_memberships
5202 .iter()
5203 .map(|cm| proto::ChannelMembership {
5204 channel_id: cm.channel_id.to_proto(),
5205 role: cm.role.into(),
5206 })
5207 .collect(),
5208 ..Default::default()
5209 };
5210
5211 let mut update = build_channels_update(result.new_channels);
5212 update.delete_channels = result
5213 .removed_channels
5214 .into_iter()
5215 .map(|id| id.to_proto())
5216 .collect();
5217 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5218
5219 for connection_id in connection_pool.user_connection_ids(user_id) {
5220 peer.send(connection_id, user_channels_update.clone())
5221 .trace_err();
5222 peer.send(connection_id, update.clone()).trace_err();
5223 }
5224}
5225
5226fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5227 proto::UpdateUserChannels {
5228 channel_memberships: channels
5229 .channel_memberships
5230 .iter()
5231 .map(|m| proto::ChannelMembership {
5232 channel_id: m.channel_id.to_proto(),
5233 role: m.role.into(),
5234 })
5235 .collect(),
5236 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5237 observed_channel_message_id: channels.observed_channel_messages.clone(),
5238 }
5239}
5240
5241fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5242 let mut update = proto::UpdateChannels::default();
5243
5244 for channel in channels.channels {
5245 update.channels.push(channel.to_proto());
5246 }
5247
5248 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5249 update.latest_channel_message_ids = channels.latest_channel_messages;
5250
5251 for (channel_id, participants) in channels.channel_participants {
5252 update
5253 .channel_participants
5254 .push(proto::ChannelParticipants {
5255 channel_id: channel_id.to_proto(),
5256 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5257 });
5258 }
5259
5260 for channel in channels.invited_channels {
5261 update.channel_invitations.push(channel.to_proto());
5262 }
5263
5264 update.hosted_projects = channels.hosted_projects;
5265 update
5266}
5267
5268fn build_initial_contacts_update(
5269 contacts: Vec<db::Contact>,
5270 pool: &ConnectionPool,
5271) -> proto::UpdateContacts {
5272 let mut update = proto::UpdateContacts::default();
5273
5274 for contact in contacts {
5275 match contact {
5276 db::Contact::Accepted { user_id, busy } => {
5277 update.contacts.push(contact_for_user(user_id, busy, &pool));
5278 }
5279 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5280 db::Contact::Incoming { user_id } => {
5281 update
5282 .incoming_requests
5283 .push(proto::IncomingContactRequest {
5284 requester_id: user_id.to_proto(),
5285 })
5286 }
5287 }
5288 }
5289
5290 update
5291}
5292
5293fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5294 proto::Contact {
5295 user_id: user_id.to_proto(),
5296 online: pool.is_user_online(user_id),
5297 busy,
5298 }
5299}
5300
5301fn room_updated(room: &proto::Room, peer: &Peer) {
5302 broadcast(
5303 None,
5304 room.participants
5305 .iter()
5306 .filter_map(|participant| Some(participant.peer_id?.into())),
5307 |peer_id| {
5308 peer.send(
5309 peer_id,
5310 proto::RoomUpdated {
5311 room: Some(room.clone()),
5312 },
5313 )
5314 },
5315 );
5316}
5317
5318fn channel_updated(
5319 channel: &db::channel::Model,
5320 room: &proto::Room,
5321 peer: &Peer,
5322 pool: &ConnectionPool,
5323) {
5324 let participants = room
5325 .participants
5326 .iter()
5327 .map(|p| p.user_id)
5328 .collect::<Vec<_>>();
5329
5330 broadcast(
5331 None,
5332 pool.channel_connection_ids(channel.root_id())
5333 .filter_map(|(channel_id, role)| {
5334 role.can_see_channel(channel.visibility).then(|| channel_id)
5335 }),
5336 |peer_id| {
5337 peer.send(
5338 peer_id,
5339 proto::UpdateChannels {
5340 channel_participants: vec![proto::ChannelParticipants {
5341 channel_id: channel.id.to_proto(),
5342 participant_user_ids: participants.clone(),
5343 }],
5344 ..Default::default()
5345 },
5346 )
5347 },
5348 );
5349}
5350
5351async fn send_dev_server_projects_update(
5352 user_id: UserId,
5353 mut status: proto::DevServerProjectsUpdate,
5354 session: &Session,
5355) {
5356 let pool = session.connection_pool().await;
5357 for dev_server in &mut status.dev_servers {
5358 dev_server.status =
5359 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5360 }
5361 let connections = pool.user_connection_ids(user_id);
5362 for connection_id in connections {
5363 session.peer.send(connection_id, status.clone()).trace_err();
5364 }
5365}
5366
5367async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5368 let db = session.db().await;
5369
5370 let contacts = db.get_contacts(user_id).await?;
5371 let busy = db.is_user_busy(user_id).await?;
5372
5373 let pool = session.connection_pool().await;
5374 let updated_contact = contact_for_user(user_id, busy, &pool);
5375 for contact in contacts {
5376 if let db::Contact::Accepted {
5377 user_id: contact_user_id,
5378 ..
5379 } = contact
5380 {
5381 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5382 session
5383 .peer
5384 .send(
5385 contact_conn_id,
5386 proto::UpdateContacts {
5387 contacts: vec![updated_contact.clone()],
5388 remove_contacts: Default::default(),
5389 incoming_requests: Default::default(),
5390 remove_incoming_requests: Default::default(),
5391 outgoing_requests: Default::default(),
5392 remove_outgoing_requests: Default::default(),
5393 },
5394 )
5395 .trace_err();
5396 }
5397 }
5398 }
5399 Ok(())
5400}
5401
5402async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5403 log::info!("lost dev server connection, unsharing projects");
5404 let project_ids = session
5405 .db()
5406 .await
5407 .get_stale_dev_server_projects(session.connection_id)
5408 .await?;
5409
5410 for project_id in project_ids {
5411 // not unshare re-checks the connection ids match, so we get away with no transaction
5412 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5413 }
5414
5415 let user_id = session.dev_server().user_id;
5416 let update = session
5417 .db()
5418 .await
5419 .dev_server_projects_update(user_id)
5420 .await?;
5421
5422 send_dev_server_projects_update(user_id, update, session).await;
5423
5424 Ok(())
5425}
5426
/// Remove `connection_id` from whatever room it is in, on behalf of `session`'s user.
///
/// Returns `Ok(())` without doing anything when the database reports that the
/// connection was not in a room. Otherwise it:
/// - notifies the collaborators of each project the user left (via `project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room belongs to a channel, refreshes that channel's participant list;
/// - sends `CallCanceled` to every user whose pending call into this room was
///   canceled;
/// - pushes refreshed contact info for this user and for each of those users;
/// - removes the user from the LiveKit room, and deletes the LiveKit room when
///   the database marks the left room as `deleted`.
///
/// LiveKit and message-send failures are passed to `trace_err` rather than
/// propagated. Database and contact-update errors are propagated.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below, so the values
    // taken out of `left_room` outlive the block that owns it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // this `if let` scrutinee, so it is presumably kept alive for the whole
    // `if let`/`else` (including `project_left`/`room_updated`) — keep that in
    // mind before restructuring this into a `let … else`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` below, so the stored `room` no longer carries the
        // LiveKit room name.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // Best effort: LiveKit errors are passed to `trace_err` and do not fail the leave.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5500
5501async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5502 let left_channel_buffers = session
5503 .db()
5504 .await
5505 .leave_channel_buffers(session.connection_id)
5506 .await?;
5507
5508 for left_buffer in left_channel_buffers {
5509 channel_buffer_updated(
5510 session.connection_id,
5511 left_buffer.connections,
5512 &proto::UpdateChannelBufferCollaborators {
5513 channel_id: left_buffer.channel_id.to_proto(),
5514 collaborators: left_buffer.collaborators,
5515 },
5516 &session.peer,
5517 );
5518 }
5519
5520 Ok(())
5521}
5522
5523fn project_left(project: &db::LeftProject, session: &UserSession) {
5524 for connection_id in &project.connection_ids {
5525 if project.should_unshare {
5526 session
5527 .peer
5528 .send(
5529 *connection_id,
5530 proto::UnshareProject {
5531 project_id: project.id.to_proto(),
5532 },
5533 )
5534 .trace_err();
5535 } else {
5536 session
5537 .peer
5538 .send(
5539 *connection_id,
5540 proto::RemoveProjectCollaborator {
5541 project_id: project.id.to_proto(),
5542 peer_id: Some(session.connection_id.into()),
5543 },
5544 )
5545 .trace_err();
5546 }
5547 }
5548}
5549
5550pub trait ResultExt {
5551 type Ok;
5552
5553 fn trace_err(self) -> Option<Self::Ok>;
5554}
5555
5556impl<T, E> ResultExt for Result<T, E>
5557where
5558 E: std::fmt::Debug,
5559{
5560 type Ok = T;
5561
5562 #[track_caller]
5563 fn trace_err(self) -> Option<T> {
5564 match self {
5565 Ok(value) => Some(value),
5566 Err(error) => {
5567 tracing::error!("{:?}", error);
5568 None
5569 }
5570 }
5571 }
5572}