1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
8 MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
9 RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37
38use futures::{
39 channel::oneshot,
40 future::{self, BoxFuture},
41 stream::FuturesUnordered,
42 FutureExt, SinkExt, StreamExt, TryStreamExt,
43};
44use prometheus::{register_int_gauge, IntGauge};
45use rpc::{
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
48 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 atomic::{AtomicBool, Ordering::SeqCst},
64 Arc, OnceLock,
65 },
66 time::{Duration, Instant},
67};
68use time::OffsetDateTime;
69use tokio::sync::{watch, Semaphore};
70use tower::ServiceBuilder;
71use tracing::{
72 field::{self},
73 info_span, instrument, Instrument,
74};
75use util::http::IsahcHttpClient;
76
/// Reconnection grace period of 30 seconds (its consumers are outside this excerpt).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before `Server::start` cleans up resources belonging to stale servers.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully; waiting somewhat
/// longer ensures those pods are gone before their old resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size for channel message queries (presumably; consumers not shown here).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on channel message length (unit not visible here — TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; consumers not shown here).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
85
86type MessageHandler =
87 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
89struct Response<R> {
90 peer: Arc<Peer>,
91 receipt: Receipt<R>,
92 responded: Arc<AtomicBool>,
93}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
103struct StreamingResponse<R: RequestMessage> {
104 peer: Arc<Peer>,
105 receipt: Receipt<R>,
106}
107
108impl<R: RequestMessage> StreamingResponse<R> {
109 fn send(&self, payload: R::Response) -> Result<()> {
110 self.peer.respond(self.receipt, payload)?;
111 Ok(())
112 }
113}
114
115#[derive(Clone, Debug)]
116pub enum Principal {
117 User(User),
118 Impersonated { user: User, admin: User },
119 DevServer(dev_server::Model),
120}
121
122impl Principal {
123 fn update_span(&self, span: &tracing::Span) {
124 match &self {
125 Principal::User(user) => {
126 span.record("user_id", &user.id.0);
127 span.record("login", &user.github_login);
128 }
129 Principal::Impersonated { user, admin } => {
130 span.record("user_id", &user.id.0);
131 span.record("login", &user.github_login);
132 span.record("impersonator", &admin.github_login);
133 }
134 Principal::DevServer(dev_server) => {
135 span.record("dev_server_id", &dev_server.id.0);
136 }
137 }
138 }
139}
140
141#[derive(Clone)]
142struct Session {
143 principal: Principal,
144 connection_id: ConnectionId,
145 db: Arc<tokio::sync::Mutex<DbHandle>>,
146 peer: Arc<Peer>,
147 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
148 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
149 http_client: IsahcHttpClient,
150 rate_limiter: Arc<RateLimiter>,
151 _executor: Executor,
152}
153
154impl Session {
155 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.db.lock().await;
159 #[cfg(test)]
160 tokio::task::yield_now().await;
161 guard
162 }
163
164 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
165 #[cfg(test)]
166 tokio::task::yield_now().await;
167 let guard = self.connection_pool.lock();
168 ConnectionPoolGuard {
169 guard,
170 _not_send: PhantomData,
171 }
172 }
173
174 fn for_user(self) -> Option<UserSession> {
175 UserSession::new(self)
176 }
177
178 fn for_dev_server(self) -> Option<DevServerSession> {
179 DevServerSession::new(self)
180 }
181
182 fn user_id(&self) -> Option<UserId> {
183 match &self.principal {
184 Principal::User(user) => Some(user.id),
185 Principal::Impersonated { user, .. } => Some(user.id),
186 Principal::DevServer(_) => None,
187 }
188 }
189
190 fn dev_server_id(&self) -> Option<DevServerId> {
191 match &self.principal {
192 Principal::User(_) | Principal::Impersonated { .. } => None,
193 Principal::DevServer(dev_server) => Some(dev_server.id),
194 }
195 }
196
197 fn principal_id(&self) -> PrincipalId {
198 match &self.principal {
199 Principal::User(user) => PrincipalId::UserId(user.id),
200 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
201 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
202 }
203 }
204}
205
206impl Debug for Session {
207 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
208 let mut result = f.debug_struct("Session");
209 match &self.principal {
210 Principal::User(user) => {
211 result.field("user", &user.github_login);
212 }
213 Principal::Impersonated { user, admin } => {
214 result.field("user", &user.github_login);
215 result.field("impersonator", &admin.github_login);
216 }
217 Principal::DevServer(dev_server) => {
218 result.field("dev_server", &dev_server.id);
219 }
220 }
221 result.field("connection_id", &self.connection_id).finish()
222 }
223}
224
225struct UserSession(Session);
226
227impl UserSession {
228 pub fn new(s: Session) -> Option<Self> {
229 s.user_id().map(|_| UserSession(s))
230 }
231 pub fn user_id(&self) -> UserId {
232 self.0.user_id().unwrap()
233 }
234}
235
236impl Deref for UserSession {
237 type Target = Session;
238
239 fn deref(&self) -> &Self::Target {
240 &self.0
241 }
242}
243impl DerefMut for UserSession {
244 fn deref_mut(&mut self) -> &mut Self::Target {
245 &mut self.0
246 }
247}
248
249struct DevServerSession(Session);
250
251impl DevServerSession {
252 pub fn new(s: Session) -> Option<Self> {
253 s.dev_server_id().map(|_| DevServerSession(s))
254 }
255 pub fn dev_server_id(&self) -> DevServerId {
256 self.0.dev_server_id().unwrap()
257 }
258
259 fn dev_server(&self) -> &dev_server::Model {
260 match &self.0.principal {
261 Principal::DevServer(dev_server) => dev_server,
262 _ => unreachable!(),
263 }
264 }
265}
266
267impl Deref for DevServerSession {
268 type Target = Session;
269
270 fn deref(&self) -> &Self::Target {
271 &self.0
272 }
273}
274impl DerefMut for DevServerSession {
275 fn deref_mut(&mut self) -> &mut Self::Target {
276 &mut self.0
277 }
278}
279
280fn user_handler<M: RequestMessage, Fut>(
281 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
282) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
283where
284 Fut: Send + Future<Output = Result<()>>,
285{
286 let handler = Arc::new(handler);
287 move |message, response, session| {
288 let handler = handler.clone();
289 Box::pin(async move {
290 if let Some(user_session) = session.for_user() {
291 Ok(handler(message, response, user_session).await?)
292 } else {
293 Err(Error::Internal(anyhow!(
294 "must be a user to call {}",
295 M::NAME
296 )))
297 }
298 })
299 }
300}
301
302fn dev_server_handler<M: RequestMessage, Fut>(
303 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
304) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
305where
306 Fut: Send + Future<Output = Result<()>>,
307{
308 let handler = Arc::new(handler);
309 move |message, response, session| {
310 let handler = handler.clone();
311 Box::pin(async move {
312 if let Some(dev_server_session) = session.for_dev_server() {
313 Ok(handler(message, response, dev_server_session).await?)
314 } else {
315 Err(Error::Internal(anyhow!(
316 "must be a dev server to call {}",
317 M::NAME
318 )))
319 }
320 })
321 }
322}
323
324fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
325 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
326) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
327where
328 InnertRetFut: Send + Future<Output = Result<()>>,
329{
330 let handler = Arc::new(handler);
331 move |message, session| {
332 let handler = handler.clone();
333 Box::pin(async move {
334 if let Some(user_session) = session.for_user() {
335 Ok(handler(message, user_session).await?)
336 } else {
337 Err(Error::Internal(anyhow!(
338 "must be a user to call {}",
339 M::NAME
340 )))
341 }
342 })
343 }
344}
345
346struct DbHandle(Arc<Database>);
347
348impl Deref for DbHandle {
349 type Target = Database;
350
351 fn deref(&self) -> &Self::Target {
352 self.0.as_ref()
353 }
354}
355
/// The collab RPC server: owns the peer, the shared connection pool, and the table of
/// message handlers keyed by message type.
pub struct Server {
    // Behind a mutex, presumably so the test-only `reset` can replace it through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message `TypeId`; populated in `Server::new` via `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown`; connection loops subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
364
/// Lock guard over the shared [`ConnectionPool`], as returned by `Session::connection_pool`.
///
/// `_not_send` is a `PhantomData<Rc<()>>`, which makes the guard neither `Send` nor `Sync`;
/// as a result it cannot be held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
369
/// Serializable view of the server's peer and connection pool.
///
/// It owns the pool's lock guard, so the pool stays locked for as long as a snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
376
377pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
378where
379 S: Serializer,
380 T: Deref<Target = U>,
381 U: Serialize,
382{
383 Serialize::serialize(value.deref(), serializer)
384}
385
386impl Server {
387 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
388 let mut server = Self {
389 id: parking_lot::Mutex::new(id),
390 peer: Peer::new(id.0 as u32),
391 app_state: app_state.clone(),
392 connection_pool: Default::default(),
393 handlers: Default::default(),
394 teardown: watch::channel(false).0,
395 };
396
397 server
398 .add_request_handler(ping)
399 .add_request_handler(user_handler(create_room))
400 .add_request_handler(user_handler(join_room))
401 .add_request_handler(user_handler(rejoin_room))
402 .add_request_handler(user_handler(leave_room))
403 .add_request_handler(user_handler(set_room_participant_role))
404 .add_request_handler(user_handler(call))
405 .add_request_handler(user_handler(cancel_call))
406 .add_message_handler(user_message_handler(decline_call))
407 .add_request_handler(user_handler(update_participant_location))
408 .add_request_handler(user_handler(share_project))
409 .add_message_handler(unshare_project)
410 .add_request_handler(user_handler(join_project))
411 .add_request_handler(user_handler(join_hosted_project))
412 .add_request_handler(user_handler(rejoin_remote_projects))
413 .add_request_handler(user_handler(create_remote_project))
414 .add_request_handler(user_handler(create_dev_server))
415 .add_request_handler(user_handler(delete_dev_server))
416 .add_request_handler(dev_server_handler(share_remote_project))
417 .add_request_handler(dev_server_handler(shutdown_dev_server))
418 .add_request_handler(dev_server_handler(reconnect_dev_server))
419 .add_message_handler(user_message_handler(leave_project))
420 .add_request_handler(update_project)
421 .add_request_handler(update_worktree)
422 .add_message_handler(start_language_server)
423 .add_message_handler(update_language_server)
424 .add_message_handler(update_diagnostic_summary)
425 .add_message_handler(update_worktree_settings)
426 .add_request_handler(user_handler(
427 forward_read_only_project_request::<proto::GetHover>,
428 ))
429 .add_request_handler(user_handler(
430 forward_read_only_project_request::<proto::GetDefinition>,
431 ))
432 .add_request_handler(user_handler(
433 forward_read_only_project_request::<proto::GetTypeDefinition>,
434 ))
435 .add_request_handler(user_handler(
436 forward_read_only_project_request::<proto::GetReferences>,
437 ))
438 .add_request_handler(user_handler(
439 forward_read_only_project_request::<proto::SearchProject>,
440 ))
441 .add_request_handler(user_handler(
442 forward_read_only_project_request::<proto::GetDocumentHighlights>,
443 ))
444 .add_request_handler(user_handler(
445 forward_read_only_project_request::<proto::GetProjectSymbols>,
446 ))
447 .add_request_handler(user_handler(
448 forward_read_only_project_request::<proto::OpenBufferForSymbol>,
449 ))
450 .add_request_handler(user_handler(
451 forward_read_only_project_request::<proto::OpenBufferById>,
452 ))
453 .add_request_handler(user_handler(
454 forward_read_only_project_request::<proto::SynchronizeBuffers>,
455 ))
456 .add_request_handler(user_handler(
457 forward_read_only_project_request::<proto::InlayHints>,
458 ))
459 .add_request_handler(user_handler(
460 forward_read_only_project_request::<proto::OpenBufferByPath>,
461 ))
462 .add_request_handler(user_handler(
463 forward_mutating_project_request::<proto::GetCompletions>,
464 ))
465 .add_request_handler(user_handler(
466 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
467 ))
468 .add_request_handler(user_handler(
469 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
470 ))
471 .add_request_handler(user_handler(
472 forward_mutating_project_request::<proto::GetCodeActions>,
473 ))
474 .add_request_handler(user_handler(
475 forward_mutating_project_request::<proto::ApplyCodeAction>,
476 ))
477 .add_request_handler(user_handler(
478 forward_mutating_project_request::<proto::PrepareRename>,
479 ))
480 .add_request_handler(user_handler(
481 forward_mutating_project_request::<proto::PerformRename>,
482 ))
483 .add_request_handler(user_handler(
484 forward_mutating_project_request::<proto::ReloadBuffers>,
485 ))
486 .add_request_handler(user_handler(
487 forward_mutating_project_request::<proto::FormatBuffers>,
488 ))
489 .add_request_handler(user_handler(
490 forward_mutating_project_request::<proto::CreateProjectEntry>,
491 ))
492 .add_request_handler(user_handler(
493 forward_mutating_project_request::<proto::RenameProjectEntry>,
494 ))
495 .add_request_handler(user_handler(
496 forward_mutating_project_request::<proto::CopyProjectEntry>,
497 ))
498 .add_request_handler(user_handler(
499 forward_mutating_project_request::<proto::DeleteProjectEntry>,
500 ))
501 .add_request_handler(user_handler(
502 forward_mutating_project_request::<proto::ExpandProjectEntry>,
503 ))
504 .add_request_handler(user_handler(
505 forward_mutating_project_request::<proto::OnTypeFormatting>,
506 ))
507 .add_request_handler(user_handler(
508 forward_mutating_project_request::<proto::SaveBuffer>,
509 ))
510 .add_request_handler(user_handler(
511 forward_mutating_project_request::<proto::BlameBuffer>,
512 ))
513 .add_request_handler(user_handler(
514 forward_mutating_project_request::<proto::MultiLspQuery>,
515 ))
516 .add_message_handler(create_buffer_for_peer)
517 .add_request_handler(update_buffer)
518 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
519 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
520 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
521 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
522 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
523 .add_request_handler(get_users)
524 .add_request_handler(user_handler(fuzzy_search_users))
525 .add_request_handler(user_handler(request_contact))
526 .add_request_handler(user_handler(remove_contact))
527 .add_request_handler(user_handler(respond_to_contact_request))
528 .add_request_handler(user_handler(create_channel))
529 .add_request_handler(user_handler(delete_channel))
530 .add_request_handler(user_handler(invite_channel_member))
531 .add_request_handler(user_handler(remove_channel_member))
532 .add_request_handler(user_handler(set_channel_member_role))
533 .add_request_handler(user_handler(set_channel_visibility))
534 .add_request_handler(user_handler(rename_channel))
535 .add_request_handler(user_handler(join_channel_buffer))
536 .add_request_handler(user_handler(leave_channel_buffer))
537 .add_message_handler(user_message_handler(update_channel_buffer))
538 .add_request_handler(user_handler(rejoin_channel_buffers))
539 .add_request_handler(user_handler(get_channel_members))
540 .add_request_handler(user_handler(respond_to_channel_invite))
541 .add_request_handler(user_handler(join_channel))
542 .add_request_handler(user_handler(join_channel_chat))
543 .add_message_handler(user_message_handler(leave_channel_chat))
544 .add_request_handler(user_handler(send_channel_message))
545 .add_request_handler(user_handler(remove_channel_message))
546 .add_request_handler(user_handler(update_channel_message))
547 .add_request_handler(user_handler(get_channel_messages))
548 .add_request_handler(user_handler(get_channel_messages_by_id))
549 .add_request_handler(user_handler(get_notifications))
550 .add_request_handler(user_handler(mark_notification_as_read))
551 .add_request_handler(user_handler(move_channel))
552 .add_request_handler(user_handler(follow))
553 .add_message_handler(user_message_handler(unfollow))
554 .add_message_handler(user_message_handler(update_followers))
555 .add_request_handler(user_handler(get_private_user_info))
556 .add_message_handler(user_message_handler(acknowledge_channel_message))
557 .add_message_handler(user_message_handler(acknowledge_buffer_version))
558 .add_streaming_request_handler({
559 let app_state = app_state.clone();
560 move |request, response, session| {
561 complete_with_language_model(
562 request,
563 response,
564 session,
565 app_state.config.openai_api_key.clone(),
566 app_state.config.google_ai_api_key.clone(),
567 app_state.config.anthropic_api_key.clone(),
568 )
569 }
570 })
571 .add_request_handler({
572 let app_state = app_state.clone();
573 user_handler(move |request, response, session| {
574 count_tokens_with_language_model(
575 request,
576 response,
577 session,
578 app_state.config.google_ai_api_key.clone(),
579 )
580 })
581 })
582 .add_request_handler({
583 user_handler(move |request, response, session| {
584 get_cached_embeddings(request, response, session)
585 })
586 })
587 .add_request_handler({
588 let app_state = app_state.clone();
589 user_handler(move |request, response, session| {
590 compute_embeddings(
591 request,
592 response,
593 session,
594 app_state.config.openai_api_key.clone(),
595 )
596 })
597 });
598
599 Arc::new(server)
600 }
601
    /// Schedules a background cleanup of resources that the database reports as stale for
    /// this server id, then returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, a detached task:
    /// 1. fetches stale room ids and channel ids for `server_id` in the configured environment;
    /// 2. clears stale channel-buffer collaborators and sends the pruned collaborator list to
    ///    the remaining connections;
    /// 3. for each stale room, clears stale participants, broadcasts the room (and channel)
    ///    update, notifies users whose calls were canceled, refreshes affected users' contact
    ///    entries, and deletes the LiveKit room if no participants remain;
    /// 4. deletes the stale server records.
    ///
    /// Database and send failures are logged via `trace_err` and skipped rather than aborting
    /// the task. Always returns `Ok(())`; the spawned task's outcome is only logged.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        // Detached: `start` does not wait for the cleanup to finish.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell each connection still in the buffer about the pruned
                            // collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; they keep their defaults
                        // when refreshing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both stale participants and canceled-call recipients get their
                            // contact entries refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skipped if either query failed; the pool is locked only after
                            // both queries have completed.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are sent the updated entry.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
747
748 pub fn teardown(&self) {
749 self.peer.teardown();
750 self.connection_pool.lock().reset();
751 let _ = self.teardown.send(true);
752 }
753
754 #[cfg(test)]
755 pub fn reset(&self, id: ServerId) {
756 self.teardown();
757 *self.id.lock() = id;
758 self.peer.reset(id.0 as u32);
759 let _ = self.teardown.send(false);
760 }
761
762 #[cfg(test)]
763 pub fn id(&self) -> ServerId {
764 *self.id.lock()
765 }
766
767 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
768 where
769 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
770 Fut: 'static + Send + Future<Output = Result<()>>,
771 M: EnvelopedMessage,
772 {
773 let prev_handler = self.handlers.insert(
774 TypeId::of::<M>(),
775 Box::new(move |envelope, session| {
776 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
777 let received_at = envelope.received_at;
778 tracing::info!("message received");
779 let start_time = Instant::now();
780 let future = (handler)(*envelope, session);
781 async move {
782 let result = future.await;
783 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
784 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
785 let queue_duration_ms = total_duration_ms - processing_duration_ms;
786 let payload_type = M::NAME;
787
788 match result {
789 Err(error) => {
790 tracing::error!(
791 ?error,
792 total_duration_ms,
793 processing_duration_ms,
794 queue_duration_ms,
795 payload_type,
796 "error handling message"
797 )
798 }
799 Ok(()) => tracing::info!(
800 total_duration_ms,
801 processing_duration_ms,
802 queue_duration_ms,
803 "finished handling message"
804 ),
805 }
806 }
807 .boxed()
808 }),
809 );
810 if prev_handler.is_some() {
811 panic!("registered a handler for the same message twice");
812 }
813 self
814 }
815
816 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
817 where
818 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
819 Fut: 'static + Send + Future<Output = Result<()>>,
820 M: EnvelopedMessage,
821 {
822 self.add_handler(move |envelope, session| handler(envelope.payload, session));
823 self
824 }
825
826 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
827 where
828 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
829 Fut: Send + Future<Output = Result<()>>,
830 M: RequestMessage,
831 {
832 let handler = Arc::new(handler);
833 self.add_handler(move |envelope, session| {
834 let receipt = envelope.receipt();
835 let handler = handler.clone();
836 async move {
837 let peer = session.peer.clone();
838 let responded = Arc::new(AtomicBool::default());
839 let response = Response {
840 peer: peer.clone(),
841 responded: responded.clone(),
842 receipt,
843 };
844 match (handler)(envelope.payload, response, session).await {
845 Ok(()) => {
846 if responded.load(std::sync::atomic::Ordering::SeqCst) {
847 Ok(())
848 } else {
849 Err(anyhow!("handler did not send a response"))?
850 }
851 }
852 Err(error) => {
853 let proto_err = match &error {
854 Error::Internal(err) => err.to_proto(),
855 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
856 };
857 peer.respond_with_error(receipt, proto_err)?;
858 Err(error)
859 }
860 }
861 }
862 })
863 }
864
865 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
866 where
867 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
868 Fut: Send + Future<Output = Result<()>>,
869 M: RequestMessage,
870 {
871 let handler = Arc::new(handler);
872 self.add_handler(move |envelope, session| {
873 let receipt = envelope.receipt();
874 let handler = handler.clone();
875 async move {
876 let peer = session.peer.clone();
877 let response = StreamingResponse {
878 peer: peer.clone(),
879 receipt,
880 };
881 match (handler)(envelope.payload, response, session).await {
882 Ok(()) => {
883 peer.end_stream(receipt)?;
884 Ok(())
885 }
886 Err(error) => {
887 let proto_err = match &error {
888 Error::Internal(err) => err.to_proto(),
889 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
890 };
891 peer.respond_with_error(receipt, proto_err)?;
892 Err(error)
893 }
894 }
895 }
896 })
897 }
898
    /// Returns a future that serves one client connection until it closes or the server tears
    /// down.
    ///
    /// The future registers `connection` with the peer, builds the connection's `Session`,
    /// sends the initial client update (which also delivers the connection id through
    /// `send_connection_id`, if given), and then dispatches incoming messages to the
    /// registered handlers. Background messages are spawned on `executor`; foreground ones are
    /// driven inside the loop. At most 256 messages are being handled at once. When the I/O
    /// future finishes or the incoming stream ends, `connection_lost` signs the session out.
    /// If teardown is signalled, the future returns without calling `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: `connection_id` once the peer assigns
        // it, the principal fields via `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if teardown has already begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after `send_initial_client_update`)
            // happens after `add_connection` but skips `connection_lost`; confirm the
            // connection is cleaned up elsewhere.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before each message is read and released when its handler
            // finishes (or right away if no handler exists), capping in-flight messages at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, then I/O, then in-flight foreground
                // handlers, and only then the next incoming message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // The enter guard borrows `span`; it must end before `span` is
                                // moved into `instrument` below.
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1035
    /// Sends the initial state a freshly connected client needs.
    ///
    /// A `proto::Hello` carrying the client's peer id is always sent first, and
    /// `connection_id` is then reported through `send_connection_id` (if one was
    /// supplied; a dropped receiver is ignored). What follows depends on `principal`:
    ///
    /// * Users (including impersonated ones): on the very first connection a
    ///   `proto::ShowContacts` is sent and the `connected_once` flag is persisted.
    ///   Contacts, channels, channel invites and remote projects are then loaded
    ///   concurrently, the connection is registered in the connection pool and
    ///   subscribed to each channel membership, and the initial contacts / channels
    ///   updates, the remote-projects update, any pending incoming call and a contacts
    ///   refresh are sent.
    /// * Dev servers: the connection is registered in the pool, failing with
    ///   `ErrorCode::DevServerAlreadyOnline` if the pool already has a connection for
    ///   the same dev server. The dev server's remote projects are then sent as
    ///   `proto::DevServerInstructions` and the owning user's remote-projects status is
    ///   pushed.
    ///
    /// # Errors
    ///
    /// Returns the first failure from sending a message on the peer or from a database
    /// query; later steps are skipped in that case.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the assigned connection id back to whoever asked for it; the receiver
        // may already be gone, which is not an error.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Load everything the initial updates need concurrently.
                let (contacts, channels_for_user, channel_invites, remote_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.remote_projects_update(user.id),
                    )
                    .await?;

                // Registration and the initial updates happen under one pool lock, so the
                // contacts update is built from the same pool state this connection was
                // just added to. The lock is released before the next `.await`.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_remote_projects_update(user.id, remote_projects, session).await;

                // Re-deliver a call that is still ringing for this user.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Check-and-insert under a single lock so two concurrent connections for
                // the same dev server cannot both register.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_remote_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Refresh the owning user's view of their remote projects.
                let status = self
                    .app_state
                    .db
                    .remote_projects_update(dev_server.user_id)
                    .await?;
                send_remote_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1131
1132 pub async fn invite_code_redeemed(
1133 self: &Arc<Self>,
1134 inviter_id: UserId,
1135 invitee_id: UserId,
1136 ) -> Result<()> {
1137 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1138 if let Some(code) = &user.invite_code {
1139 let pool = self.connection_pool.lock();
1140 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1141 for connection_id in pool.user_connection_ids(inviter_id) {
1142 self.peer.send(
1143 connection_id,
1144 proto::UpdateContacts {
1145 contacts: vec![invitee_contact.clone()],
1146 ..Default::default()
1147 },
1148 )?;
1149 self.peer.send(
1150 connection_id,
1151 proto::UpdateInviteInfo {
1152 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1153 count: user.invite_count as u32,
1154 },
1155 )?;
1156 }
1157 }
1158 }
1159 Ok(())
1160 }
1161
1162 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1163 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1164 if let Some(invite_code) = &user.invite_code {
1165 let pool = self.connection_pool.lock();
1166 for connection_id in pool.user_connection_ids(user_id) {
1167 self.peer.send(
1168 connection_id,
1169 proto::UpdateInviteInfo {
1170 url: format!(
1171 "{}{}",
1172 self.app_state.config.invite_link_prefix, invite_code
1173 ),
1174 count: user.invite_count as u32,
1175 },
1176 )?;
1177 }
1178 }
1179 }
1180 Ok(())
1181 }
1182
1183 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1184 ServerSnapshot {
1185 connection_pool: ConnectionPoolGuard {
1186 guard: self.connection_pool.lock(),
1187 _not_send: PhantomData,
1188 },
1189 peer: &self.peer,
1190 }
1191 }
1192}
1193
1194impl<'a> Deref for ConnectionPoolGuard<'a> {
1195 type Target = ConnectionPool;
1196
1197 fn deref(&self) -> &Self::Target {
1198 &self.guard
1199 }
1200}
1201
1202impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1203 fn deref_mut(&mut self) -> &mut Self::Target {
1204 &mut self.guard
1205 }
1206}
1207
1208impl<'a> Drop for ConnectionPoolGuard<'a> {
1209 fn drop(&mut self) {
1210 #[cfg(test)]
1211 self.check_invariants();
1212 }
1213}
1214
1215fn broadcast<F>(
1216 sender_id: Option<ConnectionId>,
1217 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1218 mut f: F,
1219) where
1220 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1221{
1222 for receiver_id in receiver_ids {
1223 if Some(receiver_id) != sender_id {
1224 if let Err(error) = f(receiver_id) {
1225 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1226 }
1227 }
1228 }
1229}
1230
1231pub struct ProtocolVersion(u32);
1232
1233impl Header for ProtocolVersion {
1234 fn name() -> &'static HeaderName {
1235 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1236 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1237 }
1238
1239 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1240 where
1241 Self: Sized,
1242 I: Iterator<Item = &'i axum::http::HeaderValue>,
1243 {
1244 let version = values
1245 .next()
1246 .ok_or_else(axum::headers::Error::invalid)?
1247 .to_str()
1248 .map_err(|_| axum::headers::Error::invalid())?
1249 .parse()
1250 .map_err(|_| axum::headers::Error::invalid())?;
1251 Ok(Self(version))
1252 }
1253
1254 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1255 values.extend([self.0.to_string().parse().unwrap()]);
1256 }
1257}
1258
1259pub struct AppVersionHeader(SemanticVersion);
1260impl Header for AppVersionHeader {
1261 fn name() -> &'static HeaderName {
1262 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1263 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1264 }
1265
1266 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1267 where
1268 Self: Sized,
1269 I: Iterator<Item = &'i axum::http::HeaderValue>,
1270 {
1271 let version = values
1272 .next()
1273 .ok_or_else(axum::headers::Error::invalid)?
1274 .to_str()
1275 .map_err(|_| axum::headers::Error::invalid())?
1276 .parse()
1277 .map_err(|_| axum::headers::Error::invalid())?;
1278 Ok(Self(version))
1279 }
1280
1281 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1282 values.extend([self.0.to_string().parse().unwrap()]);
1283 }
1284}
1285
/// Builds the collab server's HTTP router.
///
/// * `/rpc` upgrades to the websocket RPC connection. It is the only route wrapped by
///   the `AppState` extension and the `auth::validate_header` middleware, because axum's
///   `Router::layer` only applies to routes registered before the call.
/// * `/metrics` serves Prometheus metrics. It is registered after that layer, so
///   `auth::validate_header` does not run for it.
///
/// The final `Extension(server)` layer applies to both routes.
///
/// NOTE(review): `handle_websocket_request` extracts `Extension<Principal>`, which
/// presumably is inserted by `auth::validate_header` — confirm in the auth module.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Must stay between the two `.route` calls: it only wraps `/rpc`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1297
1298pub async fn handle_websocket_request(
1299 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1300 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1301 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1302 Extension(server): Extension<Arc<Server>>,
1303 Extension(principal): Extension<Principal>,
1304 ws: WebSocketUpgrade,
1305) -> axum::response::Response {
1306 if protocol_version != rpc::PROTOCOL_VERSION {
1307 return (
1308 StatusCode::UPGRADE_REQUIRED,
1309 "client must be upgraded".to_string(),
1310 )
1311 .into_response();
1312 }
1313
1314 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1315 return (
1316 StatusCode::UPGRADE_REQUIRED,
1317 "no version header found".to_string(),
1318 )
1319 .into_response();
1320 };
1321
1322 if !version.can_collaborate() {
1323 return (
1324 StatusCode::UPGRADE_REQUIRED,
1325 "client must be upgraded".to_string(),
1326 )
1327 .into_response();
1328 }
1329
1330 let socket_address = socket_address.to_string();
1331 ws.on_upgrade(move |socket| {
1332 let socket = socket
1333 .map_ok(to_tungstenite_message)
1334 .err_into()
1335 .with(|message| async move { Ok(to_axum_message(message)) });
1336 let connection = Connection::new(Box::pin(socket));
1337 async move {
1338 server
1339 .handle_connection(
1340 connection,
1341 socket_address,
1342 principal,
1343 version,
1344 None,
1345 Executor::Production,
1346 )
1347 .await;
1348 }
1349 })
1350}
1351
1352pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1353 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1354 let connections_metric = CONNECTIONS_METRIC
1355 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1356
1357 let connections = server
1358 .connection_pool
1359 .lock()
1360 .connections()
1361 .filter(|connection| !connection.admin)
1362 .count();
1363 connections_metric.set(connections as _);
1364
1365 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1366 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1367 register_int_gauge!(
1368 "shared_projects",
1369 "number of open projects with one or more guests"
1370 )
1371 .unwrap()
1372 });
1373
1374 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1375 shared_projects_metric.set(shared_projects as _);
1376
1377 let encoder = prometheus::TextEncoder::new();
1378 let metric_families = prometheus::gather();
1379 let encoded_metrics = encoder
1380 .encode_to_string(&metric_families)
1381 .map_err(|err| anyhow!("{}", err))?;
1382 Ok(encoded_metrics)
1383}
1384
/// Cleans up after a client connection has gone away.
///
/// The function proceeds in three steps:
///
/// 1. It immediately disconnects the peer and removes the connection from the pool.
/// 2. It records the loss in the database; a database error here is only logged.
/// 3. It waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to change:
///    * If `teardown` changes first, nothing more is done.
///    * Otherwise, for users, the session leaves its room and channel buffers, a
///      pending call is declined if the user has no connection left online, and
///      contacts are refreshed.
///    * For dev servers, `lost_dev_server_connection` runs.
///
/// NOTE(review): the name `RECONNECT_TIMEOUT` suggests a grace period during which the
/// client may reconnect before its resources are released — confirm.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool, or if the final contacts
/// refresh / dev-server cleanup fails. Errors from leaving rooms or channel buffers and
/// from declining calls are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls in order, so the timeout branch wins if both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline calls once no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1439
1440/// Acknowledges a ping from a client, used to keep the connection alive.
1441async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1442 response.send(proto::Ack {})?;
1443 Ok(())
1444}
1445
1446/// Creates a new room for calling (outside of channels)
1447async fn create_room(
1448 _request: proto::CreateRoom,
1449 response: Response<proto::CreateRoom>,
1450 session: UserSession,
1451) -> Result<()> {
1452 let live_kit_room = nanoid::nanoid!(30);
1453
1454 let live_kit_connection_info = util::maybe!(async {
1455 let live_kit = session.live_kit_client.as_ref();
1456 let live_kit = live_kit?;
1457 let user_id = session.user_id().to_string();
1458
1459 let token = live_kit
1460 .room_token(&live_kit_room, &user_id.to_string())
1461 .trace_err()?;
1462
1463 Some(proto::LiveKitConnectionInfo {
1464 server_url: live_kit.url().into(),
1465 token,
1466 can_publish: true,
1467 })
1468 })
1469 .await;
1470
1471 let room = session
1472 .db()
1473 .await
1474 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1475 .await?;
1476
1477 response.send(proto::CreateRoomResponse {
1478 room: Some(room.clone()),
1479 live_kit_connection_info,
1480 })?;
1481
1482 update_user_contacts(session.user_id(), &session).await?;
1483 Ok(())
1484}
1485
1486/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1487async fn join_room(
1488 request: proto::JoinRoom,
1489 response: Response<proto::JoinRoom>,
1490 session: UserSession,
1491) -> Result<()> {
1492 let room_id = RoomId::from_proto(request.id);
1493
1494 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1495
1496 if let Some(channel_id) = channel_id {
1497 return join_channel_internal(channel_id, Box::new(response), session).await;
1498 }
1499
1500 let joined_room = {
1501 let room = session
1502 .db()
1503 .await
1504 .join_room(room_id, session.user_id(), session.connection_id)
1505 .await?;
1506 room_updated(&room.room, &session.peer);
1507 room.into_inner()
1508 };
1509
1510 for connection_id in session
1511 .connection_pool()
1512 .await
1513 .user_connection_ids(session.user_id())
1514 {
1515 session
1516 .peer
1517 .send(
1518 connection_id,
1519 proto::CallCanceled {
1520 room_id: room_id.to_proto(),
1521 },
1522 )
1523 .trace_err();
1524 }
1525
1526 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1527 if let Some(token) = live_kit
1528 .room_token(
1529 &joined_room.room.live_kit_room,
1530 &session.user_id().to_string(),
1531 )
1532 .trace_err()
1533 {
1534 Some(proto::LiveKitConnectionInfo {
1535 server_url: live_kit.url().into(),
1536 token,
1537 can_publish: true,
1538 })
1539 } else {
1540 None
1541 }
1542 } else {
1543 None
1544 };
1545
1546 response.send(proto::JoinRoomResponse {
1547 room: Some(joined_room.room),
1548 channel_id: None,
1549 live_kit_connection_info,
1550 })?;
1551
1552 update_user_contacts(session.user_id(), &session).await?;
1553 Ok(())
1554}
1555
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the user's room membership through the database and replies with three
/// things: the room, the "reshared" projects with their collaborators, and the
/// "rejoined" projects. After that:
///
/// * every collaborator of each reshared project is sent a
///   `proto::UpdateProjectCollaborator` mapping the project's `old_connection_id` to
///   this connection;
/// * every such collaborator other than this connection is also forwarded the
///   project's worktrees as a `proto::UpdateProject`;
/// * `notify_rejoined_projects` updates the collaborators of the rejoined projects and
///   streams those projects back to this connection;
/// * if the room belongs to a channel, `channel_updated` is called for it;
/// * the user's contacts are refreshed.
///
/// NOTE(review): the names suggest reshared projects are those this user hosts and
/// rejoined projects those they were a guest in — confirm against the database layer.
///
/// # Errors
///
/// Propagates database errors, a failure to send the response, send failures to this
/// connection while streaming rejoined projects, and failures to refresh contacts.
/// Failed sends to collaborators are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned from the database result inside the scope below, used after it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators this connection replaces the old one for the project.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // The database result is unwrapped only after every notification above was sent.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Runs after the scope above, once `rejoined_room` has been consumed.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1647
1648fn notify_rejoined_projects(
1649 rejoined_projects: &mut Vec<RejoinedProject>,
1650 session: &UserSession,
1651) -> Result<()> {
1652 for project in rejoined_projects.iter() {
1653 for collaborator in &project.collaborators {
1654 session
1655 .peer
1656 .send(
1657 collaborator.connection_id,
1658 proto::UpdateProjectCollaborator {
1659 project_id: project.id.to_proto(),
1660 old_peer_id: Some(project.old_connection_id.into()),
1661 new_peer_id: Some(session.connection_id.into()),
1662 },
1663 )
1664 .trace_err();
1665 }
1666 }
1667
1668 for project in rejoined_projects {
1669 for worktree in mem::take(&mut project.worktrees) {
1670 #[cfg(any(test, feature = "test-support"))]
1671 const MAX_CHUNK_SIZE: usize = 2;
1672 #[cfg(not(any(test, feature = "test-support")))]
1673 const MAX_CHUNK_SIZE: usize = 256;
1674
1675 // Stream this worktree's entries.
1676 let message = proto::UpdateWorktree {
1677 project_id: project.id.to_proto(),
1678 worktree_id: worktree.id,
1679 abs_path: worktree.abs_path.clone(),
1680 root_name: worktree.root_name,
1681 updated_entries: worktree.updated_entries,
1682 removed_entries: worktree.removed_entries,
1683 scan_id: worktree.scan_id,
1684 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1685 updated_repositories: worktree.updated_repositories,
1686 removed_repositories: worktree.removed_repositories,
1687 };
1688 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1689 session.peer.send(session.connection_id, update.clone())?;
1690 }
1691
1692 // Stream this worktree's diagnostics.
1693 for summary in worktree.diagnostic_summaries {
1694 session.peer.send(
1695 session.connection_id,
1696 proto::UpdateDiagnosticSummary {
1697 project_id: project.id.to_proto(),
1698 worktree_id: worktree.id,
1699 summary: Some(summary),
1700 },
1701 )?;
1702 }
1703
1704 for settings_file in worktree.settings_files {
1705 session.peer.send(
1706 session.connection_id,
1707 proto::UpdateWorktreeSettings {
1708 project_id: project.id.to_proto(),
1709 worktree_id: worktree.id,
1710 path: settings_file.path,
1711 content: Some(settings_file.content),
1712 },
1713 )?;
1714 }
1715 }
1716
1717 for language_server in &project.language_servers {
1718 session.peer.send(
1719 session.connection_id,
1720 proto::UpdateLanguageServer {
1721 project_id: project.id.to_proto(),
1722 language_server_id: language_server.id,
1723 variant: Some(
1724 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1725 proto::LspDiskBasedDiagnosticsUpdated {},
1726 ),
1727 ),
1728 },
1729 )?;
1730 }
1731 }
1732 Ok(())
1733}
1734
1735/// leave room disconnects from the room.
1736async fn leave_room(
1737 _: proto::LeaveRoom,
1738 response: Response<proto::LeaveRoom>,
1739 session: UserSession,
1740) -> Result<()> {
1741 leave_room_for_session(&session, session.connection_id).await?;
1742 response.send(proto::Ack {})?;
1743 Ok(())
1744}
1745
1746/// Updates the permissions of someone else in the room.
1747async fn set_room_participant_role(
1748 request: proto::SetRoomParticipantRole,
1749 response: Response<proto::SetRoomParticipantRole>,
1750 session: UserSession,
1751) -> Result<()> {
1752 let user_id = UserId::from_proto(request.user_id);
1753 let role = ChannelRole::from(request.role());
1754
1755 let (live_kit_room, can_publish) = {
1756 let room = session
1757 .db()
1758 .await
1759 .set_room_participant_role(
1760 session.user_id(),
1761 RoomId::from_proto(request.room_id),
1762 user_id,
1763 role,
1764 )
1765 .await?;
1766
1767 let live_kit_room = room.live_kit_room.clone();
1768 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1769 room_updated(&room, &session.peer);
1770 (live_kit_room, can_publish)
1771 };
1772
1773 if let Some(live_kit) = session.live_kit_client.as_ref() {
1774 live_kit
1775 .update_participant(
1776 live_kit_room.clone(),
1777 request.user_id.to_string(),
1778 live_kit_server::proto::ParticipantPermission {
1779 can_subscribe: true,
1780 can_publish,
1781 can_publish_data: can_publish,
1782 hidden: false,
1783 recorder: false,
1784 },
1785 )
1786 .await
1787 .trace_err();
1788 }
1789
1790 response.send(proto::Ack {})?;
1791 Ok(())
1792}
1793
/// Call someone else into the current room.
///
/// Only contacts can be called. The call is recorded in the database and
/// `room_updated` is called with the new room state. Every connection of the called
/// user is then sent the incoming call as a request, and the caller is acknowledged as
/// soon as any one of them responds successfully.
///
/// If none responds successfully, the call is recorded as failed, `room_updated` is
/// called again, and an error is returned. The called user's contacts are refreshed
/// after the call is placed and again after a failure.
///
/// # Errors
///
/// Fails when the callee is not a contact, on database errors, or with
/// "failed to ring user" when no connection accepted the call.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` is dropped before anything else is awaited;
    // only the incoming-call message is moved out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently. The pool guard only lives for
    // this statement, i.e. while the request futures are created.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response acknowledges the call; failures are logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted: record the failure and publish the updated room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1862
1863/// Cancel an outgoing call.
1864async fn cancel_call(
1865 request: proto::CancelCall,
1866 response: Response<proto::CancelCall>,
1867 session: UserSession,
1868) -> Result<()> {
1869 let called_user_id = UserId::from_proto(request.called_user_id);
1870 let room_id = RoomId::from_proto(request.room_id);
1871 {
1872 let room = session
1873 .db()
1874 .await
1875 .cancel_call(room_id, session.connection_id, called_user_id)
1876 .await?;
1877 room_updated(&room, &session.peer);
1878 }
1879
1880 for connection_id in session
1881 .connection_pool()
1882 .await
1883 .user_connection_ids(called_user_id)
1884 {
1885 session
1886 .peer
1887 .send(
1888 connection_id,
1889 proto::CallCanceled {
1890 room_id: room_id.to_proto(),
1891 },
1892 )
1893 .trace_err();
1894 }
1895 response.send(proto::Ack {})?;
1896
1897 update_user_contacts(called_user_id, &session).await?;
1898 Ok(())
1899}
1900
1901/// Decline an incoming call.
1902async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1903 let room_id = RoomId::from_proto(message.room_id);
1904 {
1905 let room = session
1906 .db()
1907 .await
1908 .decline_call(Some(room_id), session.user_id())
1909 .await?
1910 .ok_or_else(|| anyhow!("failed to decline call"))?;
1911 room_updated(&room, &session.peer);
1912 }
1913
1914 for connection_id in session
1915 .connection_pool()
1916 .await
1917 .user_connection_ids(session.user_id())
1918 {
1919 session
1920 .peer
1921 .send(
1922 connection_id,
1923 proto::CallCanceled {
1924 room_id: room_id.to_proto(),
1925 },
1926 )
1927 .trace_err();
1928 }
1929 update_user_contacts(session.user_id(), &session).await?;
1930 Ok(())
1931}
1932
1933/// Updates other participants in the room with your current location.
1934async fn update_participant_location(
1935 request: proto::UpdateParticipantLocation,
1936 response: Response<proto::UpdateParticipantLocation>,
1937 session: UserSession,
1938) -> Result<()> {
1939 let room_id = RoomId::from_proto(request.room_id);
1940 let location = request
1941 .location
1942 .ok_or_else(|| anyhow!("invalid location"))?;
1943
1944 let db = session.db().await;
1945 let room = db
1946 .update_room_participant_location(room_id, session.connection_id, location)
1947 .await?;
1948
1949 room_updated(&room, &session.peer);
1950 response.send(proto::Ack {})?;
1951 Ok(())
1952}
1953
1954/// Share a project into the room.
1955async fn share_project(
1956 request: proto::ShareProject,
1957 response: Response<proto::ShareProject>,
1958 session: UserSession,
1959) -> Result<()> {
1960 let (project_id, room) = &*session
1961 .db()
1962 .await
1963 .share_project(
1964 RoomId::from_proto(request.room_id),
1965 session.connection_id,
1966 &request.worktrees,
1967 request
1968 .remote_project_id
1969 .map(|id| RemoteProjectId::from_proto(id)),
1970 )
1971 .await?;
1972 response.send(proto::ShareProjectResponse {
1973 project_id: project_id.to_proto(),
1974 })?;
1975 room_updated(&room, &session.peer);
1976
1977 Ok(())
1978}
1979
1980/// Unshare a project from the room.
1981async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1982 let project_id = ProjectId::from_proto(message.project_id);
1983 unshare_project_internal(
1984 project_id,
1985 session.connection_id,
1986 session.user_id(),
1987 &session,
1988 )
1989 .await
1990}
1991
1992async fn unshare_project_internal(
1993 project_id: ProjectId,
1994 connection_id: ConnectionId,
1995 user_id: Option<UserId>,
1996 session: &Session,
1997) -> Result<()> {
1998 let (room, guest_connection_ids) = &*session
1999 .db()
2000 .await
2001 .unshare_project(project_id, connection_id, user_id)
2002 .await?;
2003
2004 let message = proto::UnshareProject {
2005 project_id: project_id.to_proto(),
2006 };
2007
2008 broadcast(
2009 Some(connection_id),
2010 guest_connection_ids.iter().copied(),
2011 |conn_id| session.peer.send(conn_id, message.clone()),
2012 );
2013 if let Some(room) = room {
2014 room_updated(room, &session.peer);
2015 }
2016
2017 Ok(())
2018}
2019
2020/// DevServer makes a project available online
2021async fn share_remote_project(
2022 request: proto::ShareRemoteProject,
2023 response: Response<proto::ShareRemoteProject>,
2024 session: DevServerSession,
2025) -> Result<()> {
2026 let (remote_project, user_id, status) = session
2027 .db()
2028 .await
2029 .share_remote_project(
2030 RemoteProjectId::from_proto(request.remote_project_id),
2031 session.dev_server_id(),
2032 session.connection_id,
2033 &request.worktrees,
2034 )
2035 .await?;
2036 let Some(project_id) = remote_project.project_id else {
2037 return Err(anyhow!("failed to share remote project"))?;
2038 };
2039
2040 send_remote_projects_update(user_id, status, &session).await;
2041
2042 response.send(proto::ShareProjectResponse { project_id })?;
2043
2044 Ok(())
2045}
2046
2047/// Join someone elses shared project.
2048async fn join_project(
2049 request: proto::JoinProject,
2050 response: Response<proto::JoinProject>,
2051 session: UserSession,
2052) -> Result<()> {
2053 let project_id = ProjectId::from_proto(request.project_id);
2054
2055 tracing::info!(%project_id, "join project");
2056
2057 let db = session.db().await;
2058 let (project, replica_id) = &mut *db
2059 .join_project(project_id, session.connection_id, session.user_id())
2060 .await?;
2061 drop(db);
2062 tracing::info!(%project_id, "join remote project");
2063 join_project_internal(response, session, project, replica_id)
2064}
2065
/// Abstracts over the response handles that can answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `proto::JoinProject` and `proto::JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response::send` for this concrete request type.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response::send` for this concrete request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2079
2080fn join_project_internal(
2081 response: impl JoinProjectInternalResponse,
2082 session: UserSession,
2083 project: &mut Project,
2084 replica_id: &ReplicaId,
2085) -> Result<()> {
2086 let collaborators = project
2087 .collaborators
2088 .iter()
2089 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2090 .map(|collaborator| collaborator.to_proto())
2091 .collect::<Vec<_>>();
2092 let project_id = project.id;
2093 let guest_user_id = session.user_id();
2094
2095 let worktrees = project
2096 .worktrees
2097 .iter()
2098 .map(|(id, worktree)| proto::WorktreeMetadata {
2099 id: *id,
2100 root_name: worktree.root_name.clone(),
2101 visible: worktree.visible,
2102 abs_path: worktree.abs_path.clone(),
2103 })
2104 .collect::<Vec<_>>();
2105
2106 let add_project_collaborator = proto::AddProjectCollaborator {
2107 project_id: project_id.to_proto(),
2108 collaborator: Some(proto::Collaborator {
2109 peer_id: Some(session.connection_id.into()),
2110 replica_id: replica_id.0 as u32,
2111 user_id: guest_user_id.to_proto(),
2112 }),
2113 };
2114
2115 for collaborator in &collaborators {
2116 session
2117 .peer
2118 .send(
2119 collaborator.peer_id.unwrap().into(),
2120 add_project_collaborator.clone(),
2121 )
2122 .trace_err();
2123 }
2124
2125 // First, we send the metadata associated with each worktree.
2126 response.send(proto::JoinProjectResponse {
2127 project_id: project.id.0 as u64,
2128 worktrees: worktrees.clone(),
2129 replica_id: replica_id.0 as u32,
2130 collaborators: collaborators.clone(),
2131 language_servers: project.language_servers.clone(),
2132 role: project.role.into(),
2133 remote_project_id: project
2134 .remote_project_id
2135 .map(|remote_project_id| remote_project_id.0 as u64),
2136 })?;
2137
2138 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2139 #[cfg(any(test, feature = "test-support"))]
2140 const MAX_CHUNK_SIZE: usize = 2;
2141 #[cfg(not(any(test, feature = "test-support")))]
2142 const MAX_CHUNK_SIZE: usize = 256;
2143
2144 // Stream this worktree's entries.
2145 let message = proto::UpdateWorktree {
2146 project_id: project_id.to_proto(),
2147 worktree_id,
2148 abs_path: worktree.abs_path.clone(),
2149 root_name: worktree.root_name,
2150 updated_entries: worktree.entries,
2151 removed_entries: Default::default(),
2152 scan_id: worktree.scan_id,
2153 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2154 updated_repositories: worktree.repository_entries.into_values().collect(),
2155 removed_repositories: Default::default(),
2156 };
2157 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2158 session.peer.send(session.connection_id, update.clone())?;
2159 }
2160
2161 // Stream this worktree's diagnostics.
2162 for summary in worktree.diagnostic_summaries {
2163 session.peer.send(
2164 session.connection_id,
2165 proto::UpdateDiagnosticSummary {
2166 project_id: project_id.to_proto(),
2167 worktree_id: worktree.id,
2168 summary: Some(summary),
2169 },
2170 )?;
2171 }
2172
2173 for settings_file in worktree.settings_files {
2174 session.peer.send(
2175 session.connection_id,
2176 proto::UpdateWorktreeSettings {
2177 project_id: project_id.to_proto(),
2178 worktree_id: worktree.id,
2179 path: settings_file.path,
2180 content: Some(settings_file.content),
2181 },
2182 )?;
2183 }
2184 }
2185
2186 for language_server in &project.language_servers {
2187 session.peer.send(
2188 session.connection_id,
2189 proto::UpdateLanguageServer {
2190 project_id: project_id.to_proto(),
2191 language_server_id: language_server.id,
2192 variant: Some(
2193 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2194 proto::LspDiskBasedDiagnosticsUpdated {},
2195 ),
2196 ),
2197 },
2198 )?;
2199 }
2200
2201 Ok(())
2202}
2203
2204/// Leave someone elses shared project.
2205async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2206 let sender_id = session.connection_id;
2207 let project_id = ProjectId::from_proto(request.project_id);
2208 let db = session.db().await;
2209 if db.is_hosted_project(project_id).await? {
2210 let project = db.leave_hosted_project(project_id, sender_id).await?;
2211 project_left(&project, &session);
2212 return Ok(());
2213 }
2214
2215 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2216 tracing::info!(
2217 %project_id,
2218 "leave project"
2219 );
2220
2221 project_left(&project, &session);
2222 if let Some(room) = room {
2223 room_updated(&room, &session.peer);
2224 }
2225
2226 Ok(())
2227}
2228
2229async fn join_hosted_project(
2230 request: proto::JoinHostedProject,
2231 response: Response<proto::JoinHostedProject>,
2232 session: UserSession,
2233) -> Result<()> {
2234 let (mut project, replica_id) = session
2235 .db()
2236 .await
2237 .join_hosted_project(
2238 ProjectId(request.project_id as i32),
2239 session.user_id(),
2240 session.connection_id,
2241 )
2242 .await?;
2243
2244 join_project_internal(response, session, &mut project, &replica_id)
2245}
2246
2247async fn create_remote_project(
2248 request: proto::CreateRemoteProject,
2249 response: Response<proto::CreateRemoteProject>,
2250 session: UserSession,
2251) -> Result<()> {
2252 let dev_server_id = DevServerId(request.dev_server_id as i32);
2253 let dev_server_connection_id = session
2254 .connection_pool()
2255 .await
2256 .dev_server_connection_id(dev_server_id);
2257 let Some(dev_server_connection_id) = dev_server_connection_id else {
2258 Err(ErrorCode::DevServerOffline
2259 .message("Cannot create a remote project when the dev server is offline".to_string())
2260 .anyhow())?
2261 };
2262
2263 let path = request.path.clone();
2264 //Check that the path exists on the dev server
2265 session
2266 .peer
2267 .forward_request(
2268 session.connection_id,
2269 dev_server_connection_id,
2270 proto::ValidateRemoteProjectRequest { path: path.clone() },
2271 )
2272 .await?;
2273
2274 let (remote_project, update) = session
2275 .db()
2276 .await
2277 .create_remote_project(
2278 DevServerId(request.dev_server_id as i32),
2279 &request.path,
2280 session.user_id(),
2281 )
2282 .await?;
2283
2284 let projects = session
2285 .db()
2286 .await
2287 .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2288 .await?;
2289
2290 session.peer.send(
2291 dev_server_connection_id,
2292 proto::DevServerInstructions { projects },
2293 )?;
2294
2295 send_remote_projects_update(session.user_id(), update, &session).await;
2296
2297 response.send(proto::CreateRemoteProjectResponse {
2298 remote_project: Some(remote_project.to_proto(None)),
2299 })?;
2300 Ok(())
2301}
2302
2303async fn create_dev_server(
2304 request: proto::CreateDevServer,
2305 response: Response<proto::CreateDevServer>,
2306 session: UserSession,
2307) -> Result<()> {
2308 let access_token = auth::random_token();
2309 let hashed_access_token = auth::hash_access_token(&access_token);
2310
2311 let (dev_server, status) = session
2312 .db()
2313 .await
2314 .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2315 .await?;
2316
2317 send_remote_projects_update(session.user_id(), status, &session).await;
2318
2319 response.send(proto::CreateDevServerResponse {
2320 dev_server_id: dev_server.id.0 as u64,
2321 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2322 name: request.name.clone(),
2323 })?;
2324 Ok(())
2325}
2326
2327async fn delete_dev_server(
2328 request: proto::DeleteDevServer,
2329 response: Response<proto::DeleteDevServer>,
2330 session: UserSession,
2331) -> Result<()> {
2332 let dev_server_id = DevServerId(request.dev_server_id as i32);
2333 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2334 if dev_server.user_id != session.user_id() {
2335 return Err(anyhow!(ErrorCode::Forbidden))?;
2336 }
2337
2338 let connection_id = session
2339 .connection_pool()
2340 .await
2341 .dev_server_connection_id(dev_server_id);
2342 if let Some(connection_id) = connection_id {
2343 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2344 session
2345 .peer
2346 .send(connection_id, proto::ShutdownDevServer {})?;
2347 }
2348
2349 let status = session
2350 .db()
2351 .await
2352 .delete_dev_server(dev_server_id, session.user_id())
2353 .await?;
2354
2355 send_remote_projects_update(session.user_id(), status, &session).await;
2356
2357 response.send(proto::Ack {})?;
2358 Ok(())
2359}
2360
2361async fn rejoin_remote_projects(
2362 request: proto::RejoinRemoteProjects,
2363 response: Response<proto::RejoinRemoteProjects>,
2364 session: UserSession,
2365) -> Result<()> {
2366 let mut rejoined_projects = {
2367 let db = session.db().await;
2368 db.rejoin_remote_projects(
2369 &request.rejoined_projects,
2370 session.user_id(),
2371 session.0.connection_id,
2372 )
2373 .await?
2374 };
2375 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2376
2377 response.send(proto::RejoinRemoteProjectsResponse {
2378 rejoined_projects: rejoined_projects
2379 .into_iter()
2380 .map(|project| project.to_proto())
2381 .collect(),
2382 })
2383}
2384
2385async fn reconnect_dev_server(
2386 request: proto::ReconnectDevServer,
2387 response: Response<proto::ReconnectDevServer>,
2388 session: DevServerSession,
2389) -> Result<()> {
2390 let reshared_projects = {
2391 let db = session.db().await;
2392 db.reshare_remote_projects(
2393 &request.reshared_projects,
2394 session.dev_server_id(),
2395 session.0.connection_id,
2396 )
2397 .await?
2398 };
2399
2400 for project in &reshared_projects {
2401 for collaborator in &project.collaborators {
2402 session
2403 .peer
2404 .send(
2405 collaborator.connection_id,
2406 proto::UpdateProjectCollaborator {
2407 project_id: project.id.to_proto(),
2408 old_peer_id: Some(project.old_connection_id.into()),
2409 new_peer_id: Some(session.connection_id.into()),
2410 },
2411 )
2412 .trace_err();
2413 }
2414
2415 broadcast(
2416 Some(session.connection_id),
2417 project
2418 .collaborators
2419 .iter()
2420 .map(|collaborator| collaborator.connection_id),
2421 |connection_id| {
2422 session.peer.forward_send(
2423 session.connection_id,
2424 connection_id,
2425 proto::UpdateProject {
2426 project_id: project.id.to_proto(),
2427 worktrees: project.worktrees.clone(),
2428 },
2429 )
2430 },
2431 );
2432 }
2433
2434 response.send(proto::ReconnectDevServerResponse {
2435 reshared_projects: reshared_projects
2436 .iter()
2437 .map(|project| proto::ResharedProject {
2438 id: project.id.to_proto(),
2439 collaborators: project
2440 .collaborators
2441 .iter()
2442 .map(|collaborator| collaborator.to_proto())
2443 .collect(),
2444 })
2445 .collect(),
2446 })?;
2447
2448 Ok(())
2449}
2450
2451async fn shutdown_dev_server(
2452 _: proto::ShutdownDevServer,
2453 response: Response<proto::ShutdownDevServer>,
2454 session: DevServerSession,
2455) -> Result<()> {
2456 response.send(proto::Ack {})?;
2457 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2458}
2459
2460async fn shutdown_dev_server_internal(
2461 dev_server_id: DevServerId,
2462 connection_id: ConnectionId,
2463 session: &Session,
2464) -> Result<()> {
2465 let (remote_projects, dev_server) = {
2466 let db = session.db().await;
2467 let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2468 let dev_server = db.get_dev_server(dev_server_id).await?;
2469 (remote_projects, dev_server)
2470 };
2471
2472 for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2473 unshare_project_internal(
2474 ProjectId::from_proto(project_id),
2475 connection_id,
2476 None,
2477 session,
2478 )
2479 .await?;
2480 }
2481
2482 session
2483 .connection_pool()
2484 .await
2485 .set_dev_server_offline(dev_server_id);
2486
2487 let status = session
2488 .db()
2489 .await
2490 .remote_projects_update(dev_server.user_id)
2491 .await?;
2492 send_remote_projects_update(dev_server.user_id, status, &session).await;
2493
2494 Ok(())
2495}
2496
2497/// Updates other participants with changes to the project
2498async fn update_project(
2499 request: proto::UpdateProject,
2500 response: Response<proto::UpdateProject>,
2501 session: Session,
2502) -> Result<()> {
2503 let project_id = ProjectId::from_proto(request.project_id);
2504 let (room, guest_connection_ids) = &*session
2505 .db()
2506 .await
2507 .update_project(project_id, session.connection_id, &request.worktrees)
2508 .await?;
2509 broadcast(
2510 Some(session.connection_id),
2511 guest_connection_ids.iter().copied(),
2512 |connection_id| {
2513 session
2514 .peer
2515 .forward_send(session.connection_id, connection_id, request.clone())
2516 },
2517 );
2518 if let Some(room) = room {
2519 room_updated(&room, &session.peer);
2520 }
2521 response.send(proto::Ack {})?;
2522
2523 Ok(())
2524}
2525
2526/// Updates other participants with changes to the worktree
2527async fn update_worktree(
2528 request: proto::UpdateWorktree,
2529 response: Response<proto::UpdateWorktree>,
2530 session: Session,
2531) -> Result<()> {
2532 let guest_connection_ids = session
2533 .db()
2534 .await
2535 .update_worktree(&request, session.connection_id)
2536 .await?;
2537
2538 broadcast(
2539 Some(session.connection_id),
2540 guest_connection_ids.iter().copied(),
2541 |connection_id| {
2542 session
2543 .peer
2544 .forward_send(session.connection_id, connection_id, request.clone())
2545 },
2546 );
2547 response.send(proto::Ack {})?;
2548 Ok(())
2549}
2550
2551/// Updates other participants with changes to the diagnostics
2552async fn update_diagnostic_summary(
2553 message: proto::UpdateDiagnosticSummary,
2554 session: Session,
2555) -> Result<()> {
2556 let guest_connection_ids = session
2557 .db()
2558 .await
2559 .update_diagnostic_summary(&message, session.connection_id)
2560 .await?;
2561
2562 broadcast(
2563 Some(session.connection_id),
2564 guest_connection_ids.iter().copied(),
2565 |connection_id| {
2566 session
2567 .peer
2568 .forward_send(session.connection_id, connection_id, message.clone())
2569 },
2570 );
2571
2572 Ok(())
2573}
2574
2575/// Updates other participants with changes to the worktree settings
2576async fn update_worktree_settings(
2577 message: proto::UpdateWorktreeSettings,
2578 session: Session,
2579) -> Result<()> {
2580 let guest_connection_ids = session
2581 .db()
2582 .await
2583 .update_worktree_settings(&message, session.connection_id)
2584 .await?;
2585
2586 broadcast(
2587 Some(session.connection_id),
2588 guest_connection_ids.iter().copied(),
2589 |connection_id| {
2590 session
2591 .peer
2592 .forward_send(session.connection_id, connection_id, message.clone())
2593 },
2594 );
2595
2596 Ok(())
2597}
2598
2599/// Notify other participants that a language server has started.
2600async fn start_language_server(
2601 request: proto::StartLanguageServer,
2602 session: Session,
2603) -> Result<()> {
2604 let guest_connection_ids = session
2605 .db()
2606 .await
2607 .start_language_server(&request, session.connection_id)
2608 .await?;
2609
2610 broadcast(
2611 Some(session.connection_id),
2612 guest_connection_ids.iter().copied(),
2613 |connection_id| {
2614 session
2615 .peer
2616 .forward_send(session.connection_id, connection_id, request.clone())
2617 },
2618 );
2619 Ok(())
2620}
2621
2622/// Notify other participants that a language server has changed.
2623async fn update_language_server(
2624 request: proto::UpdateLanguageServer,
2625 session: Session,
2626) -> Result<()> {
2627 let project_id = ProjectId::from_proto(request.project_id);
2628 let project_connection_ids = session
2629 .db()
2630 .await
2631 .project_connection_ids(project_id, session.connection_id, true)
2632 .await?;
2633 broadcast(
2634 Some(session.connection_id),
2635 project_connection_ids.iter().copied(),
2636 |connection_id| {
2637 session
2638 .peer
2639 .forward_send(session.connection_id, connection_id, request.clone())
2640 },
2641 );
2642 Ok(())
2643}
2644
2645/// forward a project request to the host. These requests should be read only
2646/// as guests are allowed to send them.
2647async fn forward_read_only_project_request<T>(
2648 request: T,
2649 response: Response<T>,
2650 session: UserSession,
2651) -> Result<()>
2652where
2653 T: EntityMessage + RequestMessage,
2654{
2655 let project_id = ProjectId::from_proto(request.remote_entity_id());
2656 let host_connection_id = session
2657 .db()
2658 .await
2659 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2660 .await?;
2661 let payload = session
2662 .peer
2663 .forward_request(session.connection_id, host_connection_id, request)
2664 .await?;
2665 response.send(payload)?;
2666 Ok(())
2667}
2668
2669/// forward a project request to the host. These requests are disallowed
2670/// for guests.
2671async fn forward_mutating_project_request<T>(
2672 request: T,
2673 response: Response<T>,
2674 session: UserSession,
2675) -> Result<()>
2676where
2677 T: EntityMessage + RequestMessage,
2678{
2679 let project_id = ProjectId::from_proto(request.remote_entity_id());
2680 let host_connection_id = session
2681 .db()
2682 .await
2683 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2684 .await?;
2685 let payload = session
2686 .peer
2687 .forward_request(session.connection_id, host_connection_id, request)
2688 .await?;
2689 response.send(payload)?;
2690 Ok(())
2691}
2692
2693/// Notify other participants that a new buffer has been created
2694async fn create_buffer_for_peer(
2695 request: proto::CreateBufferForPeer,
2696 session: Session,
2697) -> Result<()> {
2698 session
2699 .db()
2700 .await
2701 .check_user_is_project_host(
2702 ProjectId::from_proto(request.project_id),
2703 session.connection_id,
2704 )
2705 .await?;
2706 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2707 session
2708 .peer
2709 .forward_send(session.connection_id, peer_id.into(), request)?;
2710 Ok(())
2711}
2712
2713/// Notify other participants that a buffer has been updated. This is
2714/// allowed for guests as long as the update is limited to selections.
2715async fn update_buffer(
2716 request: proto::UpdateBuffer,
2717 response: Response<proto::UpdateBuffer>,
2718 session: Session,
2719) -> Result<()> {
2720 let project_id = ProjectId::from_proto(request.project_id);
2721 let mut capability = Capability::ReadOnly;
2722
2723 for op in request.operations.iter() {
2724 match op.variant {
2725 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2726 Some(_) => capability = Capability::ReadWrite,
2727 }
2728 }
2729
2730 let host = {
2731 let guard = session
2732 .db()
2733 .await
2734 .connections_for_buffer_update(
2735 project_id,
2736 session.principal_id(),
2737 session.connection_id,
2738 capability,
2739 )
2740 .await?;
2741
2742 let (host, guests) = &*guard;
2743
2744 broadcast(
2745 Some(session.connection_id),
2746 guests.clone(),
2747 |connection_id| {
2748 session
2749 .peer
2750 .forward_send(session.connection_id, connection_id, request.clone())
2751 },
2752 );
2753
2754 *host
2755 };
2756
2757 if host != session.connection_id {
2758 session
2759 .peer
2760 .forward_request(session.connection_id, host, request.clone())
2761 .await?;
2762 }
2763
2764 response.send(proto::Ack {})?;
2765 Ok(())
2766}
2767
2768/// Notify other participants that a project has been updated.
2769async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2770 request: T,
2771 session: Session,
2772) -> Result<()> {
2773 let project_id = ProjectId::from_proto(request.remote_entity_id());
2774 let project_connection_ids = session
2775 .db()
2776 .await
2777 .project_connection_ids(project_id, session.connection_id, false)
2778 .await?;
2779
2780 broadcast(
2781 Some(session.connection_id),
2782 project_connection_ids.iter().copied(),
2783 |connection_id| {
2784 session
2785 .peer
2786 .forward_send(session.connection_id, connection_id, request.clone())
2787 },
2788 );
2789 Ok(())
2790}
2791
2792/// Start following another user in a call.
2793async fn follow(
2794 request: proto::Follow,
2795 response: Response<proto::Follow>,
2796 session: UserSession,
2797) -> Result<()> {
2798 let room_id = RoomId::from_proto(request.room_id);
2799 let project_id = request.project_id.map(ProjectId::from_proto);
2800 let leader_id = request
2801 .leader_id
2802 .ok_or_else(|| anyhow!("invalid leader id"))?
2803 .into();
2804 let follower_id = session.connection_id;
2805
2806 session
2807 .db()
2808 .await
2809 .check_room_participants(room_id, leader_id, session.connection_id)
2810 .await?;
2811
2812 let response_payload = session
2813 .peer
2814 .forward_request(session.connection_id, leader_id, request)
2815 .await?;
2816 response.send(response_payload)?;
2817
2818 if let Some(project_id) = project_id {
2819 let room = session
2820 .db()
2821 .await
2822 .follow(room_id, project_id, leader_id, follower_id)
2823 .await?;
2824 room_updated(&room, &session.peer);
2825 }
2826
2827 Ok(())
2828}
2829
2830/// Stop following another user in a call.
2831async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2832 let room_id = RoomId::from_proto(request.room_id);
2833 let project_id = request.project_id.map(ProjectId::from_proto);
2834 let leader_id = request
2835 .leader_id
2836 .ok_or_else(|| anyhow!("invalid leader id"))?
2837 .into();
2838 let follower_id = session.connection_id;
2839
2840 session
2841 .db()
2842 .await
2843 .check_room_participants(room_id, leader_id, session.connection_id)
2844 .await?;
2845
2846 session
2847 .peer
2848 .forward_send(session.connection_id, leader_id, request)?;
2849
2850 if let Some(project_id) = project_id {
2851 let room = session
2852 .db()
2853 .await
2854 .unfollow(room_id, project_id, leader_id, follower_id)
2855 .await?;
2856 room_updated(&room, &session.peer);
2857 }
2858
2859 Ok(())
2860}
2861
2862/// Notify everyone following you of your current location.
2863async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2864 let room_id = RoomId::from_proto(request.room_id);
2865 let database = session.db.lock().await;
2866
2867 let connection_ids = if let Some(project_id) = request.project_id {
2868 let project_id = ProjectId::from_proto(project_id);
2869 database
2870 .project_connection_ids(project_id, session.connection_id, true)
2871 .await?
2872 } else {
2873 database
2874 .room_connection_ids(room_id, session.connection_id)
2875 .await?
2876 };
2877
2878 // For now, don't send view update messages back to that view's current leader.
2879 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2880 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2881 _ => None,
2882 });
2883
2884 for connection_id in connection_ids.iter().cloned() {
2885 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2886 session
2887 .peer
2888 .forward_send(session.connection_id, connection_id, request.clone())?;
2889 }
2890 }
2891 Ok(())
2892}
2893
2894/// Get public data about users.
2895async fn get_users(
2896 request: proto::GetUsers,
2897 response: Response<proto::GetUsers>,
2898 session: Session,
2899) -> Result<()> {
2900 let user_ids = request
2901 .user_ids
2902 .into_iter()
2903 .map(UserId::from_proto)
2904 .collect();
2905 let users = session
2906 .db()
2907 .await
2908 .get_users_by_ids(user_ids)
2909 .await?
2910 .into_iter()
2911 .map(|user| proto::User {
2912 id: user.id.to_proto(),
2913 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2914 github_login: user.github_login,
2915 })
2916 .collect();
2917 response.send(proto::UsersResponse { users })?;
2918 Ok(())
2919}
2920
2921/// Search for users (to invite) buy Github login
2922async fn fuzzy_search_users(
2923 request: proto::FuzzySearchUsers,
2924 response: Response<proto::FuzzySearchUsers>,
2925 session: UserSession,
2926) -> Result<()> {
2927 let query = request.query;
2928 let users = match query.len() {
2929 0 => vec![],
2930 1 | 2 => session
2931 .db()
2932 .await
2933 .get_user_by_github_login(&query)
2934 .await?
2935 .into_iter()
2936 .collect(),
2937 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2938 };
2939 let users = users
2940 .into_iter()
2941 .filter(|user| user.id != session.user_id())
2942 .map(|user| proto::User {
2943 id: user.id.to_proto(),
2944 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2945 github_login: user.github_login,
2946 })
2947 .collect();
2948 response.send(proto::UsersResponse { users })?;
2949 Ok(())
2950}
2951
2952/// Send a contact request to another user.
2953async fn request_contact(
2954 request: proto::RequestContact,
2955 response: Response<proto::RequestContact>,
2956 session: UserSession,
2957) -> Result<()> {
2958 let requester_id = session.user_id();
2959 let responder_id = UserId::from_proto(request.responder_id);
2960 if requester_id == responder_id {
2961 return Err(anyhow!("cannot add yourself as a contact"))?;
2962 }
2963
2964 let notifications = session
2965 .db()
2966 .await
2967 .send_contact_request(requester_id, responder_id)
2968 .await?;
2969
2970 // Update outgoing contact requests of requester
2971 let mut update = proto::UpdateContacts::default();
2972 update.outgoing_requests.push(responder_id.to_proto());
2973 for connection_id in session
2974 .connection_pool()
2975 .await
2976 .user_connection_ids(requester_id)
2977 {
2978 session.peer.send(connection_id, update.clone())?;
2979 }
2980
2981 // Update incoming contact requests of responder
2982 let mut update = proto::UpdateContacts::default();
2983 update
2984 .incoming_requests
2985 .push(proto::IncomingContactRequest {
2986 requester_id: requester_id.to_proto(),
2987 });
2988 let connection_pool = session.connection_pool().await;
2989 for connection_id in connection_pool.user_connection_ids(responder_id) {
2990 session.peer.send(connection_id, update.clone())?;
2991 }
2992
2993 send_notifications(&connection_pool, &session.peer, notifications);
2994
2995 response.send(proto::Ack {})?;
2996 Ok(())
2997}
2998
2999/// Accept or decline a contact request
3000async fn respond_to_contact_request(
3001 request: proto::RespondToContactRequest,
3002 response: Response<proto::RespondToContactRequest>,
3003 session: UserSession,
3004) -> Result<()> {
3005 let responder_id = session.user_id();
3006 let requester_id = UserId::from_proto(request.requester_id);
3007 let db = session.db().await;
3008 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3009 db.dismiss_contact_notification(responder_id, requester_id)
3010 .await?;
3011 } else {
3012 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3013
3014 let notifications = db
3015 .respond_to_contact_request(responder_id, requester_id, accept)
3016 .await?;
3017 let requester_busy = db.is_user_busy(requester_id).await?;
3018 let responder_busy = db.is_user_busy(responder_id).await?;
3019
3020 let pool = session.connection_pool().await;
3021 // Update responder with new contact
3022 let mut update = proto::UpdateContacts::default();
3023 if accept {
3024 update
3025 .contacts
3026 .push(contact_for_user(requester_id, requester_busy, &pool));
3027 }
3028 update
3029 .remove_incoming_requests
3030 .push(requester_id.to_proto());
3031 for connection_id in pool.user_connection_ids(responder_id) {
3032 session.peer.send(connection_id, update.clone())?;
3033 }
3034
3035 // Update requester with new contact
3036 let mut update = proto::UpdateContacts::default();
3037 if accept {
3038 update
3039 .contacts
3040 .push(contact_for_user(responder_id, responder_busy, &pool));
3041 }
3042 update
3043 .remove_outgoing_requests
3044 .push(responder_id.to_proto());
3045
3046 for connection_id in pool.user_connection_ids(requester_id) {
3047 session.peer.send(connection_id, update.clone())?;
3048 }
3049
3050 send_notifications(&pool, &session.peer, notifications);
3051 }
3052
3053 response.send(proto::Ack {})?;
3054 Ok(())
3055}
3056
3057/// Remove a contact.
3058async fn remove_contact(
3059 request: proto::RemoveContact,
3060 response: Response<proto::RemoveContact>,
3061 session: UserSession,
3062) -> Result<()> {
3063 let requester_id = session.user_id();
3064 let responder_id = UserId::from_proto(request.user_id);
3065 let db = session.db().await;
3066 let (contact_accepted, deleted_notification_id) =
3067 db.remove_contact(requester_id, responder_id).await?;
3068
3069 let pool = session.connection_pool().await;
3070 // Update outgoing contact requests of requester
3071 let mut update = proto::UpdateContacts::default();
3072 if contact_accepted {
3073 update.remove_contacts.push(responder_id.to_proto());
3074 } else {
3075 update
3076 .remove_outgoing_requests
3077 .push(responder_id.to_proto());
3078 }
3079 for connection_id in pool.user_connection_ids(requester_id) {
3080 session.peer.send(connection_id, update.clone())?;
3081 }
3082
3083 // Update incoming contact requests of responder
3084 let mut update = proto::UpdateContacts::default();
3085 if contact_accepted {
3086 update.remove_contacts.push(requester_id.to_proto());
3087 } else {
3088 update
3089 .remove_incoming_requests
3090 .push(requester_id.to_proto());
3091 }
3092 for connection_id in pool.user_connection_ids(responder_id) {
3093 session.peer.send(connection_id, update.clone())?;
3094 if let Some(notification_id) = deleted_notification_id {
3095 session.peer.send(
3096 connection_id,
3097 proto::DeleteNotification {
3098 notification_id: notification_id.to_proto(),
3099 },
3100 )?;
3101 }
3102 }
3103
3104 response.send(proto::Ack {})?;
3105 Ok(())
3106}
3107
3108/// Creates a new channel.
3109async fn create_channel(
3110 request: proto::CreateChannel,
3111 response: Response<proto::CreateChannel>,
3112 session: UserSession,
3113) -> Result<()> {
3114 let db = session.db().await;
3115
3116 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3117 let (channel, membership) = db
3118 .create_channel(&request.name, parent_id, session.user_id())
3119 .await?;
3120
3121 let root_id = channel.root_id();
3122 let channel = Channel::from_model(channel);
3123
3124 response.send(proto::CreateChannelResponse {
3125 channel: Some(channel.to_proto()),
3126 parent_id: request.parent_id,
3127 })?;
3128
3129 let mut connection_pool = session.connection_pool().await;
3130 if let Some(membership) = membership {
3131 connection_pool.subscribe_to_channel(
3132 membership.user_id,
3133 membership.channel_id,
3134 membership.role,
3135 );
3136 let update = proto::UpdateUserChannels {
3137 channel_memberships: vec![proto::ChannelMembership {
3138 channel_id: membership.channel_id.to_proto(),
3139 role: membership.role.into(),
3140 }],
3141 ..Default::default()
3142 };
3143 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3144 session.peer.send(connection_id, update.clone())?;
3145 }
3146 }
3147
3148 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149 if !role.can_see_channel(channel.visibility) {
3150 continue;
3151 }
3152
3153 let update = proto::UpdateChannels {
3154 channels: vec![channel.to_proto()],
3155 ..Default::default()
3156 };
3157 session.peer.send(connection_id, update.clone())?;
3158 }
3159
3160 Ok(())
3161}
3162
3163/// Delete a channel
3164async fn delete_channel(
3165 request: proto::DeleteChannel,
3166 response: Response<proto::DeleteChannel>,
3167 session: UserSession,
3168) -> Result<()> {
3169 let db = session.db().await;
3170
3171 let channel_id = request.channel_id;
3172 let (root_channel, removed_channels) = db
3173 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3174 .await?;
3175 response.send(proto::Ack {})?;
3176
3177 // Notify members of removed channels
3178 let mut update = proto::UpdateChannels::default();
3179 update
3180 .delete_channels
3181 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3182
3183 let connection_pool = session.connection_pool().await;
3184 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3185 session.peer.send(connection_id, update.clone())?;
3186 }
3187
3188 Ok(())
3189}
3190
3191/// Invite someone to join a channel.
3192async fn invite_channel_member(
3193 request: proto::InviteChannelMember,
3194 response: Response<proto::InviteChannelMember>,
3195 session: UserSession,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199 let invitee_id = UserId::from_proto(request.user_id);
3200 let InviteMemberResult {
3201 channel,
3202 notifications,
3203 } = db
3204 .invite_channel_member(
3205 channel_id,
3206 invitee_id,
3207 session.user_id(),
3208 request.role().into(),
3209 )
3210 .await?;
3211
3212 let update = proto::UpdateChannels {
3213 channel_invitations: vec![channel.to_proto()],
3214 ..Default::default()
3215 };
3216
3217 let connection_pool = session.connection_pool().await;
3218 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3219 session.peer.send(connection_id, update.clone())?;
3220 }
3221
3222 send_notifications(&connection_pool, &session.peer, notifications);
3223
3224 response.send(proto::Ack {})?;
3225 Ok(())
3226}
3227
3228/// remove someone from a channel
3229async fn remove_channel_member(
3230 request: proto::RemoveChannelMember,
3231 response: Response<proto::RemoveChannelMember>,
3232 session: UserSession,
3233) -> Result<()> {
3234 let db = session.db().await;
3235 let channel_id = ChannelId::from_proto(request.channel_id);
3236 let member_id = UserId::from_proto(request.user_id);
3237
3238 let RemoveChannelMemberResult {
3239 membership_update,
3240 notification_id,
3241 } = db
3242 .remove_channel_member(channel_id, member_id, session.user_id())
3243 .await?;
3244
3245 let mut connection_pool = session.connection_pool().await;
3246 notify_membership_updated(
3247 &mut connection_pool,
3248 membership_update,
3249 member_id,
3250 &session.peer,
3251 );
3252 for connection_id in connection_pool.user_connection_ids(member_id) {
3253 if let Some(notification_id) = notification_id {
3254 session
3255 .peer
3256 .send(
3257 connection_id,
3258 proto::DeleteNotification {
3259 notification_id: notification_id.to_proto(),
3260 },
3261 )
3262 .trace_err();
3263 }
3264 }
3265
3266 response.send(proto::Ack {})?;
3267 Ok(())
3268}
3269
3270/// Toggle the channel between public and private.
3271/// Care is taken to maintain the invariant that public channels only descend from public channels,
3272/// (though members-only channels can appear at any point in the hierarchy).
3273async fn set_channel_visibility(
3274 request: proto::SetChannelVisibility,
3275 response: Response<proto::SetChannelVisibility>,
3276 session: UserSession,
3277) -> Result<()> {
3278 let db = session.db().await;
3279 let channel_id = ChannelId::from_proto(request.channel_id);
3280 let visibility = request.visibility().into();
3281
3282 let channel_model = db
3283 .set_channel_visibility(channel_id, visibility, session.user_id())
3284 .await?;
3285 let root_id = channel_model.root_id();
3286 let channel = Channel::from_model(channel_model);
3287
3288 let mut connection_pool = session.connection_pool().await;
3289 for (user_id, role) in connection_pool
3290 .channel_user_ids(root_id)
3291 .collect::<Vec<_>>()
3292 .into_iter()
3293 {
3294 let update = if role.can_see_channel(channel.visibility) {
3295 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3296 proto::UpdateChannels {
3297 channels: vec![channel.to_proto()],
3298 ..Default::default()
3299 }
3300 } else {
3301 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3302 proto::UpdateChannels {
3303 delete_channels: vec![channel.id.to_proto()],
3304 ..Default::default()
3305 }
3306 };
3307
3308 for connection_id in connection_pool.user_connection_ids(user_id) {
3309 session.peer.send(connection_id, update.clone())?;
3310 }
3311 }
3312
3313 response.send(proto::Ack {})?;
3314 Ok(())
3315}
3316
3317/// Alter the role for a user in the channel.
3318async fn set_channel_member_role(
3319 request: proto::SetChannelMemberRole,
3320 response: Response<proto::SetChannelMemberRole>,
3321 session: UserSession,
3322) -> Result<()> {
3323 let db = session.db().await;
3324 let channel_id = ChannelId::from_proto(request.channel_id);
3325 let member_id = UserId::from_proto(request.user_id);
3326 let result = db
3327 .set_channel_member_role(
3328 channel_id,
3329 session.user_id(),
3330 member_id,
3331 request.role().into(),
3332 )
3333 .await?;
3334
3335 match result {
3336 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3337 let mut connection_pool = session.connection_pool().await;
3338 notify_membership_updated(
3339 &mut connection_pool,
3340 membership_update,
3341 member_id,
3342 &session.peer,
3343 )
3344 }
3345 db::SetMemberRoleResult::InviteUpdated(channel) => {
3346 let update = proto::UpdateChannels {
3347 channel_invitations: vec![channel.to_proto()],
3348 ..Default::default()
3349 };
3350
3351 for connection_id in session
3352 .connection_pool()
3353 .await
3354 .user_connection_ids(member_id)
3355 {
3356 session.peer.send(connection_id, update.clone())?;
3357 }
3358 }
3359 }
3360
3361 response.send(proto::Ack {})?;
3362 Ok(())
3363}
3364
3365/// Change the name of a channel
3366async fn rename_channel(
3367 request: proto::RenameChannel,
3368 response: Response<proto::RenameChannel>,
3369 session: UserSession,
3370) -> Result<()> {
3371 let db = session.db().await;
3372 let channel_id = ChannelId::from_proto(request.channel_id);
3373 let channel_model = db
3374 .rename_channel(channel_id, session.user_id(), &request.name)
3375 .await?;
3376 let root_id = channel_model.root_id();
3377 let channel = Channel::from_model(channel_model);
3378
3379 response.send(proto::RenameChannelResponse {
3380 channel: Some(channel.to_proto()),
3381 })?;
3382
3383 let connection_pool = session.connection_pool().await;
3384 let update = proto::UpdateChannels {
3385 channels: vec![channel.to_proto()],
3386 ..Default::default()
3387 };
3388 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3389 if role.can_see_channel(channel.visibility) {
3390 session.peer.send(connection_id, update.clone())?;
3391 }
3392 }
3393
3394 Ok(())
3395}
3396
3397/// Move a channel to a new parent.
3398async fn move_channel(
3399 request: proto::MoveChannel,
3400 response: Response<proto::MoveChannel>,
3401 session: UserSession,
3402) -> Result<()> {
3403 let channel_id = ChannelId::from_proto(request.channel_id);
3404 let to = ChannelId::from_proto(request.to);
3405
3406 let (root_id, channels) = session
3407 .db()
3408 .await
3409 .move_channel(channel_id, to, session.user_id())
3410 .await?;
3411
3412 let connection_pool = session.connection_pool().await;
3413 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3414 let channels = channels
3415 .iter()
3416 .filter_map(|channel| {
3417 if role.can_see_channel(channel.visibility) {
3418 Some(channel.to_proto())
3419 } else {
3420 None
3421 }
3422 })
3423 .collect::<Vec<_>>();
3424 if channels.is_empty() {
3425 continue;
3426 }
3427
3428 let update = proto::UpdateChannels {
3429 channels,
3430 ..Default::default()
3431 };
3432
3433 session.peer.send(connection_id, update.clone())?;
3434 }
3435
3436 response.send(Ack {})?;
3437 Ok(())
3438}
3439
3440/// Get the list of channel members
3441async fn get_channel_members(
3442 request: proto::GetChannelMembers,
3443 response: Response<proto::GetChannelMembers>,
3444 session: UserSession,
3445) -> Result<()> {
3446 let db = session.db().await;
3447 let channel_id = ChannelId::from_proto(request.channel_id);
3448 let members = db
3449 .get_channel_participant_details(channel_id, session.user_id())
3450 .await?;
3451 response.send(proto::GetChannelMembersResponse { members })?;
3452 Ok(())
3453}
3454
3455/// Accept or decline a channel invitation.
3456async fn respond_to_channel_invite(
3457 request: proto::RespondToChannelInvite,
3458 response: Response<proto::RespondToChannelInvite>,
3459 session: UserSession,
3460) -> Result<()> {
3461 let db = session.db().await;
3462 let channel_id = ChannelId::from_proto(request.channel_id);
3463 let RespondToChannelInvite {
3464 membership_update,
3465 notifications,
3466 } = db
3467 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3468 .await?;
3469
3470 let mut connection_pool = session.connection_pool().await;
3471 if let Some(membership_update) = membership_update {
3472 notify_membership_updated(
3473 &mut connection_pool,
3474 membership_update,
3475 session.user_id(),
3476 &session.peer,
3477 );
3478 } else {
3479 let update = proto::UpdateChannels {
3480 remove_channel_invitations: vec![channel_id.to_proto()],
3481 ..Default::default()
3482 };
3483
3484 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3485 session.peer.send(connection_id, update.clone())?;
3486 }
3487 };
3488
3489 send_notifications(&connection_pool, &session.peer, notifications);
3490
3491 response.send(proto::Ack {})?;
3492
3493 Ok(())
3494}
3495
3496/// Join the channels' room
3497async fn join_channel(
3498 request: proto::JoinChannel,
3499 response: Response<proto::JoinChannel>,
3500 session: UserSession,
3501) -> Result<()> {
3502 let channel_id = ChannelId::from_proto(request.channel_id);
3503 join_channel_internal(channel_id, Box::new(response), session).await
3504}
3505
/// Response handles that can complete a channel join.
///
/// `join_channel_internal` is generic over this trait so it can reply with a
/// `proto::JoinRoomResponse` whether the join was started by a `JoinChannel` request
/// or by a `JoinRoom` request (the two impls below).
trait JoinChannelInternalResponse {
    /// Reply to the requester with `result`, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3519
/// Join the room associated with `channel_id` and notify everyone affected.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room
///    first. The database guard is released around that call and re-acquired after it.
/// 2. Join the channel's room in the database. This yields the joined room, an
///    optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest token and
///    cannot publish; every other role gets a room token and can publish. A token
///    error is logged via `trace_err` and results in no connection info.
/// 4. Reply to the requester, then broadcast any membership update and the room
///    update. This all happens while the database guard from step 1/2 is still held.
/// 5. After that guard is released, notify channel observers via `channel_updated`,
///    then refresh the user's contacts.
///
/// Errors from the database, from leaving a stale room, or from sending the response
/// are propagated. If the join returned no channel, the function errors in step 5,
/// after the requester has already been answered.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Everything that needs the database guard happens inside this block; the guard
    // (and the connection-pool guard taken below) are dropped when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before leaving the room; presumably
            // `leave_room_for_session` acquires the database itself (NOTE(review): confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or when minting the token failed
        // (the `?` after `trace_err()` short-circuits this closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may only listen; all other roles may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before broadcasting to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection-pool guard from the block above has been dropped, so it is safe
    // to acquire the pool again here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3610
3611/// Start editing the channel notes
3612async fn join_channel_buffer(
3613 request: proto::JoinChannelBuffer,
3614 response: Response<proto::JoinChannelBuffer>,
3615 session: UserSession,
3616) -> Result<()> {
3617 let db = session.db().await;
3618 let channel_id = ChannelId::from_proto(request.channel_id);
3619
3620 let open_response = db
3621 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3622 .await?;
3623
3624 let collaborators = open_response.collaborators.clone();
3625 response.send(open_response)?;
3626
3627 let update = UpdateChannelBufferCollaborators {
3628 channel_id: channel_id.to_proto(),
3629 collaborators: collaborators.clone(),
3630 };
3631 channel_buffer_updated(
3632 session.connection_id,
3633 collaborators
3634 .iter()
3635 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3636 &update,
3637 &session.peer,
3638 );
3639
3640 Ok(())
3641}
3642
3643/// Edit the channel notes
3644async fn update_channel_buffer(
3645 request: proto::UpdateChannelBuffer,
3646 session: UserSession,
3647) -> Result<()> {
3648 let db = session.db().await;
3649 let channel_id = ChannelId::from_proto(request.channel_id);
3650
3651 let (collaborators, non_collaborators, epoch, version) = db
3652 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3653 .await?;
3654
3655 channel_buffer_updated(
3656 session.connection_id,
3657 collaborators,
3658 &proto::UpdateChannelBuffer {
3659 channel_id: channel_id.to_proto(),
3660 operations: request.operations,
3661 },
3662 &session.peer,
3663 );
3664
3665 let pool = &*session.connection_pool().await;
3666
3667 broadcast(
3668 None,
3669 non_collaborators
3670 .iter()
3671 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3672 |peer_id| {
3673 session.peer.send(
3674 peer_id,
3675 proto::UpdateChannels {
3676 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3677 channel_id: channel_id.to_proto(),
3678 epoch: epoch as u64,
3679 version: version.clone(),
3680 }],
3681 ..Default::default()
3682 },
3683 )
3684 },
3685 );
3686
3687 Ok(())
3688}
3689
3690/// Rejoin the channel notes after a connection blip
3691async fn rejoin_channel_buffers(
3692 request: proto::RejoinChannelBuffers,
3693 response: Response<proto::RejoinChannelBuffers>,
3694 session: UserSession,
3695) -> Result<()> {
3696 let db = session.db().await;
3697 let buffers = db
3698 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3699 .await?;
3700
3701 for rejoined_buffer in &buffers {
3702 let collaborators_to_notify = rejoined_buffer
3703 .buffer
3704 .collaborators
3705 .iter()
3706 .filter_map(|c| Some(c.peer_id?.into()));
3707 channel_buffer_updated(
3708 session.connection_id,
3709 collaborators_to_notify,
3710 &proto::UpdateChannelBufferCollaborators {
3711 channel_id: rejoined_buffer.buffer.channel_id,
3712 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3713 },
3714 &session.peer,
3715 );
3716 }
3717
3718 response.send(proto::RejoinChannelBuffersResponse {
3719 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3720 })?;
3721
3722 Ok(())
3723}
3724
3725/// Stop editing the channel notes
3726async fn leave_channel_buffer(
3727 request: proto::LeaveChannelBuffer,
3728 response: Response<proto::LeaveChannelBuffer>,
3729 session: UserSession,
3730) -> Result<()> {
3731 let db = session.db().await;
3732 let channel_id = ChannelId::from_proto(request.channel_id);
3733
3734 let left_buffer = db
3735 .leave_channel_buffer(channel_id, session.connection_id)
3736 .await?;
3737
3738 response.send(Ack {})?;
3739
3740 channel_buffer_updated(
3741 session.connection_id,
3742 left_buffer.connections,
3743 &proto::UpdateChannelBufferCollaborators {
3744 channel_id: channel_id.to_proto(),
3745 collaborators: left_buffer.collaborators,
3746 },
3747 &session.peer,
3748 );
3749
3750 Ok(())
3751}
3752
3753fn channel_buffer_updated<T: EnvelopedMessage>(
3754 sender_id: ConnectionId,
3755 collaborators: impl IntoIterator<Item = ConnectionId>,
3756 message: &T,
3757 peer: &Peer,
3758) {
3759 broadcast(Some(sender_id), collaborators, |peer_id| {
3760 peer.send(peer_id, message.clone())
3761 });
3762}
3763
3764fn send_notifications(
3765 connection_pool: &ConnectionPool,
3766 peer: &Peer,
3767 notifications: db::NotificationBatch,
3768) {
3769 for (user_id, notification) in notifications {
3770 for connection_id in connection_pool.user_connection_ids(user_id) {
3771 if let Err(error) = peer.send(
3772 connection_id,
3773 proto::AddNotification {
3774 notification: Some(notification.clone()),
3775 },
3776 ) {
3777 tracing::error!(
3778 "failed to send notification to {:?} {}",
3779 connection_id,
3780 error
3781 );
3782 }
3783 }
3784 }
3785}
3786
3787/// Send a message to the channel
3788async fn send_channel_message(
3789 request: proto::SendChannelMessage,
3790 response: Response<proto::SendChannelMessage>,
3791 session: UserSession,
3792) -> Result<()> {
3793 // Validate the message body.
3794 let body = request.body.trim().to_string();
3795 if body.len() > MAX_MESSAGE_LEN {
3796 return Err(anyhow!("message is too long"))?;
3797 }
3798 if body.is_empty() {
3799 return Err(anyhow!("message can't be blank"))?;
3800 }
3801
3802 // TODO: adjust mentions if body is trimmed
3803
3804 let timestamp = OffsetDateTime::now_utc();
3805 let nonce = request
3806 .nonce
3807 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3808
3809 let channel_id = ChannelId::from_proto(request.channel_id);
3810 let CreatedChannelMessage {
3811 message_id,
3812 participant_connection_ids,
3813 channel_members,
3814 notifications,
3815 } = session
3816 .db()
3817 .await
3818 .create_channel_message(
3819 channel_id,
3820 session.user_id(),
3821 &body,
3822 &request.mentions,
3823 timestamp,
3824 nonce.clone().into(),
3825 match request.reply_to_message_id {
3826 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3827 None => None,
3828 },
3829 )
3830 .await?;
3831
3832 let message = proto::ChannelMessage {
3833 sender_id: session.user_id().to_proto(),
3834 id: message_id.to_proto(),
3835 body,
3836 mentions: request.mentions,
3837 timestamp: timestamp.unix_timestamp() as u64,
3838 nonce: Some(nonce),
3839 reply_to_message_id: request.reply_to_message_id,
3840 edited_at: None,
3841 };
3842 broadcast(
3843 Some(session.connection_id),
3844 participant_connection_ids,
3845 |connection| {
3846 session.peer.send(
3847 connection,
3848 proto::ChannelMessageSent {
3849 channel_id: channel_id.to_proto(),
3850 message: Some(message.clone()),
3851 },
3852 )
3853 },
3854 );
3855 response.send(proto::SendChannelMessageResponse {
3856 message: Some(message),
3857 })?;
3858
3859 let pool = &*session.connection_pool().await;
3860 broadcast(
3861 None,
3862 channel_members
3863 .iter()
3864 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3865 |peer_id| {
3866 session.peer.send(
3867 peer_id,
3868 proto::UpdateChannels {
3869 latest_channel_message_ids: vec![proto::ChannelMessageId {
3870 channel_id: channel_id.to_proto(),
3871 message_id: message_id.to_proto(),
3872 }],
3873 ..Default::default()
3874 },
3875 )
3876 },
3877 );
3878 send_notifications(pool, &session.peer, notifications);
3879
3880 Ok(())
3881}
3882
3883/// Delete a channel message
3884async fn remove_channel_message(
3885 request: proto::RemoveChannelMessage,
3886 response: Response<proto::RemoveChannelMessage>,
3887 session: UserSession,
3888) -> Result<()> {
3889 let channel_id = ChannelId::from_proto(request.channel_id);
3890 let message_id = MessageId::from_proto(request.message_id);
3891 let (connection_ids, existing_notification_ids) = session
3892 .db()
3893 .await
3894 .remove_channel_message(channel_id, message_id, session.user_id())
3895 .await?;
3896
3897 broadcast(
3898 Some(session.connection_id),
3899 connection_ids,
3900 move |connection| {
3901 session.peer.send(connection, request.clone())?;
3902
3903 for notification_id in &existing_notification_ids {
3904 session.peer.send(
3905 connection,
3906 proto::DeleteNotification {
3907 notification_id: (*notification_id).to_proto(),
3908 },
3909 )?;
3910 }
3911
3912 Ok(())
3913 },
3914 );
3915 response.send(proto::Ack {})?;
3916 Ok(())
3917}
3918
3919async fn update_channel_message(
3920 request: proto::UpdateChannelMessage,
3921 response: Response<proto::UpdateChannelMessage>,
3922 session: UserSession,
3923) -> Result<()> {
3924 let channel_id = ChannelId::from_proto(request.channel_id);
3925 let message_id = MessageId::from_proto(request.message_id);
3926 let updated_at = OffsetDateTime::now_utc();
3927 let UpdatedChannelMessage {
3928 message_id,
3929 participant_connection_ids,
3930 notifications,
3931 reply_to_message_id,
3932 timestamp,
3933 deleted_mention_notification_ids,
3934 updated_mention_notifications,
3935 } = session
3936 .db()
3937 .await
3938 .update_channel_message(
3939 channel_id,
3940 message_id,
3941 session.user_id(),
3942 request.body.as_str(),
3943 &request.mentions,
3944 updated_at,
3945 )
3946 .await?;
3947
3948 let nonce = request
3949 .nonce
3950 .clone()
3951 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3952
3953 let message = proto::ChannelMessage {
3954 sender_id: session.user_id().to_proto(),
3955 id: message_id.to_proto(),
3956 body: request.body.clone(),
3957 mentions: request.mentions.clone(),
3958 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3959 nonce: Some(nonce),
3960 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3961 edited_at: Some(updated_at.unix_timestamp() as u64),
3962 };
3963
3964 response.send(proto::Ack {})?;
3965
3966 let pool = &*session.connection_pool().await;
3967 broadcast(
3968 Some(session.connection_id),
3969 participant_connection_ids,
3970 |connection| {
3971 session.peer.send(
3972 connection,
3973 proto::ChannelMessageUpdate {
3974 channel_id: channel_id.to_proto(),
3975 message: Some(message.clone()),
3976 },
3977 )?;
3978
3979 for notification_id in &deleted_mention_notification_ids {
3980 session.peer.send(
3981 connection,
3982 proto::DeleteNotification {
3983 notification_id: (*notification_id).to_proto(),
3984 },
3985 )?;
3986 }
3987
3988 for notification in &updated_mention_notifications {
3989 session.peer.send(
3990 connection,
3991 proto::UpdateNotification {
3992 notification: Some(notification.clone()),
3993 },
3994 )?;
3995 }
3996
3997 Ok(())
3998 },
3999 );
4000
4001 send_notifications(pool, &session.peer, notifications);
4002
4003 Ok(())
4004}
4005
4006/// Mark a channel message as read
4007async fn acknowledge_channel_message(
4008 request: proto::AckChannelMessage,
4009 session: UserSession,
4010) -> Result<()> {
4011 let channel_id = ChannelId::from_proto(request.channel_id);
4012 let message_id = MessageId::from_proto(request.message_id);
4013 let notifications = session
4014 .db()
4015 .await
4016 .observe_channel_message(channel_id, session.user_id(), message_id)
4017 .await?;
4018 send_notifications(
4019 &*session.connection_pool().await,
4020 &session.peer,
4021 notifications,
4022 );
4023 Ok(())
4024}
4025
4026/// Mark a buffer version as synced
4027async fn acknowledge_buffer_version(
4028 request: proto::AckBufferOperation,
4029 session: UserSession,
4030) -> Result<()> {
4031 let buffer_id = BufferId::from_proto(request.buffer_id);
4032 session
4033 .db()
4034 .await
4035 .observe_buffer_version(
4036 buffer_id,
4037 session.user_id(),
4038 request.epoch as i32,
4039 &request.version,
4040 )
4041 .await?;
4042 Ok(())
4043}
4044
4045struct CompleteWithLanguageModelRateLimit;
4046
4047impl RateLimit for CompleteWithLanguageModelRateLimit {
4048 fn capacity() -> usize {
4049 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4050 .ok()
4051 .and_then(|v| v.parse().ok())
4052 .unwrap_or(120) // Picked arbitrarily
4053 }
4054
4055 fn refill_duration() -> chrono::Duration {
4056 chrono::Duration::hours(1)
4057 }
4058
4059 fn db_name() -> &'static str {
4060 "complete-with-language-model"
4061 }
4062}
4063
4064async fn complete_with_language_model(
4065 request: proto::CompleteWithLanguageModel,
4066 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4067 session: Session,
4068 open_ai_api_key: Option<Arc<str>>,
4069 google_ai_api_key: Option<Arc<str>>,
4070 anthropic_api_key: Option<Arc<str>>,
4071) -> Result<()> {
4072 let Some(session) = session.for_user() else {
4073 return Err(anyhow!("user not found"))?;
4074 };
4075 authorize_access_to_language_models(&session).await?;
4076 session
4077 .rate_limiter
4078 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4079 .await?;
4080
4081 if request.model.starts_with("gpt") {
4082 let api_key =
4083 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4084 complete_with_open_ai(request, response, session, api_key).await?;
4085 } else if request.model.starts_with("gemini") {
4086 let api_key = google_ai_api_key
4087 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4088 complete_with_google_ai(request, response, session, api_key).await?;
4089 } else if request.model.starts_with("claude") {
4090 let api_key = anthropic_api_key
4091 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4092 complete_with_anthropic(request, response, session, api_key).await?;
4093 }
4094
4095 Ok(())
4096}
4097
4098async fn complete_with_open_ai(
4099 request: proto::CompleteWithLanguageModel,
4100 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4101 session: UserSession,
4102 api_key: Arc<str>,
4103) -> Result<()> {
4104 let mut completion_stream = open_ai::stream_completion(
4105 &session.http_client,
4106 OPEN_AI_API_URL,
4107 &api_key,
4108 crate::ai::language_model_request_to_open_ai(request)?,
4109 )
4110 .await
4111 .context("open_ai::stream_completion request failed within collab")?;
4112
4113 while let Some(event) = completion_stream.next().await {
4114 let event = event?;
4115 response.send(proto::LanguageModelResponse {
4116 choices: event
4117 .choices
4118 .into_iter()
4119 .map(|choice| proto::LanguageModelChoiceDelta {
4120 index: choice.index,
4121 delta: Some(proto::LanguageModelResponseMessage {
4122 role: choice.delta.role.map(|role| match role {
4123 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4124 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4125 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4126 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4127 } as i32),
4128 content: choice.delta.content,
4129 tool_calls: choice
4130 .delta
4131 .tool_calls
4132 .into_iter()
4133 .map(|delta| proto::ToolCallDelta {
4134 index: delta.index as u32,
4135 id: delta.id,
4136 variant: match delta.function {
4137 Some(function) => {
4138 let name = function.name;
4139 let arguments = function.arguments;
4140
4141 Some(proto::tool_call_delta::Variant::Function(
4142 proto::tool_call_delta::FunctionCallDelta {
4143 name,
4144 arguments,
4145 },
4146 ))
4147 }
4148 None => None,
4149 },
4150 })
4151 .collect(),
4152 }),
4153 finish_reason: choice.finish_reason,
4154 })
4155 .collect(),
4156 })?;
4157 }
4158
4159 Ok(())
4160}
4161
4162async fn complete_with_google_ai(
4163 request: proto::CompleteWithLanguageModel,
4164 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4165 session: UserSession,
4166 api_key: Arc<str>,
4167) -> Result<()> {
4168 let mut stream = google_ai::stream_generate_content(
4169 &session.http_client,
4170 google_ai::API_URL,
4171 api_key.as_ref(),
4172 crate::ai::language_model_request_to_google_ai(request)?,
4173 )
4174 .await
4175 .context("google_ai::stream_generate_content request failed")?;
4176
4177 while let Some(event) = stream.next().await {
4178 let event = event?;
4179 response.send(proto::LanguageModelResponse {
4180 choices: event
4181 .candidates
4182 .unwrap_or_default()
4183 .into_iter()
4184 .map(|candidate| proto::LanguageModelChoiceDelta {
4185 index: candidate.index as u32,
4186 delta: Some(proto::LanguageModelResponseMessage {
4187 role: Some(match candidate.content.role {
4188 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4189 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4190 } as i32),
4191 content: Some(
4192 candidate
4193 .content
4194 .parts
4195 .into_iter()
4196 .filter_map(|part| match part {
4197 google_ai::Part::TextPart(part) => Some(part.text),
4198 google_ai::Part::InlineDataPart(_) => None,
4199 })
4200 .collect(),
4201 ),
4202 // Tool calls are not supported for Google
4203 tool_calls: Vec::new(),
4204 }),
4205 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4206 })
4207 .collect(),
4208 })?;
4209 }
4210
4211 Ok(())
4212}
4213
4214async fn complete_with_anthropic(
4215 request: proto::CompleteWithLanguageModel,
4216 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4217 session: UserSession,
4218 api_key: Arc<str>,
4219) -> Result<()> {
4220 let model = anthropic::Model::from_id(&request.model)?;
4221
4222 let mut system_message = String::new();
4223 let messages = request
4224 .messages
4225 .into_iter()
4226 .filter_map(|message| {
4227 match message.role() {
4228 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4229 role: anthropic::Role::User,
4230 content: message.content,
4231 }),
4232 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4233 role: anthropic::Role::Assistant,
4234 content: message.content,
4235 }),
4236 // Anthropic's API breaks system instructions out as a separate field rather
4237 // than having a system message role.
4238 LanguageModelRole::LanguageModelSystem => {
4239 if !system_message.is_empty() {
4240 system_message.push_str("\n\n");
4241 }
4242 system_message.push_str(&message.content);
4243
4244 None
4245 }
4246 // We don't yet support tool calls for Anthropic
4247 LanguageModelRole::LanguageModelTool => None,
4248 }
4249 })
4250 .collect();
4251
4252 let mut stream = anthropic::stream_completion(
4253 &session.http_client,
4254 "https://api.anthropic.com",
4255 &api_key,
4256 anthropic::Request {
4257 model,
4258 messages,
4259 stream: true,
4260 system: system_message,
4261 max_tokens: 4092,
4262 },
4263 )
4264 .await?;
4265
4266 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4267
4268 while let Some(event) = stream.next().await {
4269 let event = event?;
4270
4271 match event {
4272 anthropic::ResponseEvent::MessageStart { message } => {
4273 if let Some(role) = message.role {
4274 if role == "assistant" {
4275 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4276 } else if role == "user" {
4277 current_role = proto::LanguageModelRole::LanguageModelUser;
4278 }
4279 }
4280 }
4281 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4282 match content_block {
4283 anthropic::ContentBlock::Text { text } => {
4284 if !text.is_empty() {
4285 response.send(proto::LanguageModelResponse {
4286 choices: vec![proto::LanguageModelChoiceDelta {
4287 index: 0,
4288 delta: Some(proto::LanguageModelResponseMessage {
4289 role: Some(current_role as i32),
4290 content: Some(text),
4291 tool_calls: Vec::new(),
4292 }),
4293 finish_reason: None,
4294 }],
4295 })?;
4296 }
4297 }
4298 }
4299 }
4300 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4301 anthropic::TextDelta::TextDelta { text } => {
4302 response.send(proto::LanguageModelResponse {
4303 choices: vec![proto::LanguageModelChoiceDelta {
4304 index: 0,
4305 delta: Some(proto::LanguageModelResponseMessage {
4306 role: Some(current_role as i32),
4307 content: Some(text),
4308 tool_calls: Vec::new(),
4309 }),
4310 finish_reason: None,
4311 }],
4312 })?;
4313 }
4314 },
4315 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4316 if let Some(stop_reason) = delta.stop_reason {
4317 response.send(proto::LanguageModelResponse {
4318 choices: vec![proto::LanguageModelChoiceDelta {
4319 index: 0,
4320 delta: None,
4321 finish_reason: Some(stop_reason),
4322 }],
4323 })?;
4324 }
4325 }
4326 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4327 anthropic::ResponseEvent::MessageStop {} => {}
4328 anthropic::ResponseEvent::Ping {} => {}
4329 }
4330 }
4331
4332 Ok(())
4333}
4334
4335struct CountTokensWithLanguageModelRateLimit;
4336
4337impl RateLimit for CountTokensWithLanguageModelRateLimit {
4338 fn capacity() -> usize {
4339 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4340 .ok()
4341 .and_then(|v| v.parse().ok())
4342 .unwrap_or(600) // Picked arbitrarily
4343 }
4344
4345 fn refill_duration() -> chrono::Duration {
4346 chrono::Duration::hours(1)
4347 }
4348
4349 fn db_name() -> &'static str {
4350 "count-tokens-with-language-model"
4351 }
4352}
4353
4354async fn count_tokens_with_language_model(
4355 request: proto::CountTokensWithLanguageModel,
4356 response: Response<proto::CountTokensWithLanguageModel>,
4357 session: UserSession,
4358 google_ai_api_key: Option<Arc<str>>,
4359) -> Result<()> {
4360 authorize_access_to_language_models(&session).await?;
4361
4362 if !request.model.starts_with("gemini") {
4363 return Err(anyhow!(
4364 "counting tokens for model: {:?} is not supported",
4365 request.model
4366 ))?;
4367 }
4368
4369 session
4370 .rate_limiter
4371 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4372 .await?;
4373
4374 let api_key = google_ai_api_key
4375 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4376 let tokens_response = google_ai::count_tokens(
4377 &session.http_client,
4378 google_ai::API_URL,
4379 &api_key,
4380 crate::ai::count_tokens_request_to_google_ai(request)?,
4381 )
4382 .await?;
4383 response.send(proto::CountTokensResponse {
4384 token_count: tokens_response.total_tokens as u32,
4385 })?;
4386 Ok(())
4387}
4388
4389struct ComputeEmbeddingsRateLimit;
4390
4391impl RateLimit for ComputeEmbeddingsRateLimit {
4392 fn capacity() -> usize {
4393 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4394 .ok()
4395 .and_then(|v| v.parse().ok())
4396 .unwrap_or(120) // Picked arbitrarily
4397 }
4398
4399 fn refill_duration() -> chrono::Duration {
4400 chrono::Duration::hours(1)
4401 }
4402
4403 fn db_name() -> &'static str {
4404 "compute-embeddings"
4405 }
4406}
4407
4408async fn compute_embeddings(
4409 request: proto::ComputeEmbeddings,
4410 response: Response<proto::ComputeEmbeddings>,
4411 session: UserSession,
4412 api_key: Option<Arc<str>>,
4413) -> Result<()> {
4414 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4415 authorize_access_to_language_models(&session).await?;
4416
4417 session
4418 .rate_limiter
4419 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4420 .await?;
4421
4422 let embeddings = match request.model.as_str() {
4423 "openai/text-embedding-3-small" => {
4424 open_ai::embed(
4425 &session.http_client,
4426 OPEN_AI_API_URL,
4427 &api_key,
4428 OpenAiEmbeddingModel::TextEmbedding3Small,
4429 request.texts.iter().map(|text| text.as_str()),
4430 )
4431 .await?
4432 }
4433 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4434 };
4435
4436 let embeddings = request
4437 .texts
4438 .iter()
4439 .map(|text| {
4440 let mut hasher = sha2::Sha256::new();
4441 hasher.update(text.as_bytes());
4442 let result = hasher.finalize();
4443 result.to_vec()
4444 })
4445 .zip(
4446 embeddings
4447 .data
4448 .into_iter()
4449 .map(|embedding| embedding.embedding),
4450 )
4451 .collect::<HashMap<_, _>>();
4452
4453 let db = session.db().await;
4454 db.save_embeddings(&request.model, &embeddings)
4455 .await
4456 .context("failed to save embeddings")
4457 .trace_err();
4458
4459 response.send(proto::ComputeEmbeddingsResponse {
4460 embeddings: embeddings
4461 .into_iter()
4462 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4463 .collect(),
4464 })?;
4465 Ok(())
4466}
4467
4468struct GetCachedEmbeddingsRateLimit;
4469
4470impl RateLimit for GetCachedEmbeddingsRateLimit {
4471 fn capacity() -> usize {
4472 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4473 .ok()
4474 .and_then(|v| v.parse().ok())
4475 .unwrap_or(120) // Picked arbitrarily
4476 }
4477
4478 fn refill_duration() -> chrono::Duration {
4479 chrono::Duration::hours(1)
4480 }
4481
4482 fn db_name() -> &'static str {
4483 "get-cached-embeddings"
4484 }
4485}
4486
4487async fn get_cached_embeddings(
4488 request: proto::GetCachedEmbeddings,
4489 response: Response<proto::GetCachedEmbeddings>,
4490 session: UserSession,
4491) -> Result<()> {
4492 authorize_access_to_language_models(&session).await?;
4493
4494 session
4495 .rate_limiter
4496 .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4497 .await?;
4498
4499 let db = session.db().await;
4500 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4501
4502 response.send(proto::GetCachedEmbeddingsResponse {
4503 embeddings: embeddings
4504 .into_iter()
4505 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4506 .collect(),
4507 })?;
4508 Ok(())
4509}
4510
4511async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4512 let db = session.db().await;
4513 let flags = db.get_user_flags(session.user_id()).await?;
4514 if flags.iter().any(|flag| flag == "language-models") {
4515 Ok(())
4516 } else {
4517 Err(anyhow!("permission denied"))?
4518 }
4519}
4520
4521/// Start receiving chat updates for a channel
4522async fn join_channel_chat(
4523 request: proto::JoinChannelChat,
4524 response: Response<proto::JoinChannelChat>,
4525 session: UserSession,
4526) -> Result<()> {
4527 let channel_id = ChannelId::from_proto(request.channel_id);
4528
4529 let db = session.db().await;
4530 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4531 .await?;
4532 let messages = db
4533 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4534 .await?;
4535 response.send(proto::JoinChannelChatResponse {
4536 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4537 messages,
4538 })?;
4539 Ok(())
4540}
4541
4542/// Stop receiving chat updates for a channel
4543async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4544 let channel_id = ChannelId::from_proto(request.channel_id);
4545 session
4546 .db()
4547 .await
4548 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4549 .await?;
4550 Ok(())
4551}
4552
4553/// Retrieve the chat history for a channel
4554async fn get_channel_messages(
4555 request: proto::GetChannelMessages,
4556 response: Response<proto::GetChannelMessages>,
4557 session: UserSession,
4558) -> Result<()> {
4559 let channel_id = ChannelId::from_proto(request.channel_id);
4560 let messages = session
4561 .db()
4562 .await
4563 .get_channel_messages(
4564 channel_id,
4565 session.user_id(),
4566 MESSAGE_COUNT_PER_PAGE,
4567 Some(MessageId::from_proto(request.before_message_id)),
4568 )
4569 .await?;
4570 response.send(proto::GetChannelMessagesResponse {
4571 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4572 messages,
4573 })?;
4574 Ok(())
4575}
4576
4577/// Retrieve specific chat messages
4578async fn get_channel_messages_by_id(
4579 request: proto::GetChannelMessagesById,
4580 response: Response<proto::GetChannelMessagesById>,
4581 session: UserSession,
4582) -> Result<()> {
4583 let message_ids = request
4584 .message_ids
4585 .iter()
4586 .map(|id| MessageId::from_proto(*id))
4587 .collect::<Vec<_>>();
4588 let messages = session
4589 .db()
4590 .await
4591 .get_channel_messages_by_id(session.user_id(), &message_ids)
4592 .await?;
4593 response.send(proto::GetChannelMessagesResponse {
4594 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4595 messages,
4596 })?;
4597 Ok(())
4598}
4599
4600/// Retrieve the current users notifications
4601async fn get_notifications(
4602 request: proto::GetNotifications,
4603 response: Response<proto::GetNotifications>,
4604 session: UserSession,
4605) -> Result<()> {
4606 let notifications = session
4607 .db()
4608 .await
4609 .get_notifications(
4610 session.user_id(),
4611 NOTIFICATION_COUNT_PER_PAGE,
4612 request
4613 .before_id
4614 .map(|id| db::NotificationId::from_proto(id)),
4615 )
4616 .await?;
4617 response.send(proto::GetNotificationsResponse {
4618 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4619 notifications,
4620 })?;
4621 Ok(())
4622}
4623
4624/// Mark notifications as read
4625async fn mark_notification_as_read(
4626 request: proto::MarkNotificationRead,
4627 response: Response<proto::MarkNotificationRead>,
4628 session: UserSession,
4629) -> Result<()> {
4630 let database = &session.db().await;
4631 let notifications = database
4632 .mark_notification_as_read_by_id(
4633 session.user_id(),
4634 NotificationId::from_proto(request.notification_id),
4635 )
4636 .await?;
4637 send_notifications(
4638 &*session.connection_pool().await,
4639 &session.peer,
4640 notifications,
4641 );
4642 response.send(proto::Ack {})?;
4643 Ok(())
4644}
4645
4646/// Get the current users information
4647async fn get_private_user_info(
4648 _request: proto::GetPrivateUserInfo,
4649 response: Response<proto::GetPrivateUserInfo>,
4650 session: UserSession,
4651) -> Result<()> {
4652 let db = session.db().await;
4653
4654 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4655 let user = db
4656 .get_user_by_id(session.user_id())
4657 .await?
4658 .ok_or_else(|| anyhow!("user not found"))?;
4659 let flags = db.get_user_flags(session.user_id()).await?;
4660
4661 response.send(proto::GetPrivateUserInfoResponse {
4662 metrics_id,
4663 staff: user.admin,
4664 flags,
4665 })?;
4666 Ok(())
4667}
4668
4669fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4670 match message {
4671 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4672 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4673 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4674 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4675 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4676 code: frame.code.into(),
4677 reason: frame.reason,
4678 })),
4679 }
4680}
4681
4682fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4683 match message {
4684 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4685 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4686 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4687 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4688 AxumMessage::Close(frame) => {
4689 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4690 code: frame.code.into(),
4691 reason: frame.reason,
4692 }))
4693 }
4694 }
4695}
4696
4697fn notify_membership_updated(
4698 connection_pool: &mut ConnectionPool,
4699 result: MembershipUpdated,
4700 user_id: UserId,
4701 peer: &Peer,
4702) {
4703 for membership in &result.new_channels.channel_memberships {
4704 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4705 }
4706 for channel_id in &result.removed_channels {
4707 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4708 }
4709
4710 let user_channels_update = proto::UpdateUserChannels {
4711 channel_memberships: result
4712 .new_channels
4713 .channel_memberships
4714 .iter()
4715 .map(|cm| proto::ChannelMembership {
4716 channel_id: cm.channel_id.to_proto(),
4717 role: cm.role.into(),
4718 })
4719 .collect(),
4720 ..Default::default()
4721 };
4722
4723 let mut update = build_channels_update(result.new_channels, vec![]);
4724 update.delete_channels = result
4725 .removed_channels
4726 .into_iter()
4727 .map(|id| id.to_proto())
4728 .collect();
4729 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4730
4731 for connection_id in connection_pool.user_connection_ids(user_id) {
4732 peer.send(connection_id, user_channels_update.clone())
4733 .trace_err();
4734 peer.send(connection_id, update.clone()).trace_err();
4735 }
4736}
4737
4738fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4739 proto::UpdateUserChannels {
4740 channel_memberships: channels
4741 .channel_memberships
4742 .iter()
4743 .map(|m| proto::ChannelMembership {
4744 channel_id: m.channel_id.to_proto(),
4745 role: m.role.into(),
4746 })
4747 .collect(),
4748 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4749 observed_channel_message_id: channels.observed_channel_messages.clone(),
4750 }
4751}
4752
4753fn build_channels_update(
4754 channels: ChannelsForUser,
4755 channel_invites: Vec<db::Channel>,
4756) -> proto::UpdateChannels {
4757 let mut update = proto::UpdateChannels::default();
4758
4759 for channel in channels.channels {
4760 update.channels.push(channel.to_proto());
4761 }
4762
4763 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4764 update.latest_channel_message_ids = channels.latest_channel_messages;
4765
4766 for (channel_id, participants) in channels.channel_participants {
4767 update
4768 .channel_participants
4769 .push(proto::ChannelParticipants {
4770 channel_id: channel_id.to_proto(),
4771 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4772 });
4773 }
4774
4775 for channel in channel_invites {
4776 update.channel_invitations.push(channel.to_proto());
4777 }
4778
4779 update.hosted_projects = channels.hosted_projects;
4780 update
4781}
4782
4783fn build_initial_contacts_update(
4784 contacts: Vec<db::Contact>,
4785 pool: &ConnectionPool,
4786) -> proto::UpdateContacts {
4787 let mut update = proto::UpdateContacts::default();
4788
4789 for contact in contacts {
4790 match contact {
4791 db::Contact::Accepted { user_id, busy } => {
4792 update.contacts.push(contact_for_user(user_id, busy, &pool));
4793 }
4794 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4795 db::Contact::Incoming { user_id } => {
4796 update
4797 .incoming_requests
4798 .push(proto::IncomingContactRequest {
4799 requester_id: user_id.to_proto(),
4800 })
4801 }
4802 }
4803 }
4804
4805 update
4806}
4807
4808fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4809 proto::Contact {
4810 user_id: user_id.to_proto(),
4811 online: pool.is_user_online(user_id),
4812 busy,
4813 }
4814}
4815
4816fn room_updated(room: &proto::Room, peer: &Peer) {
4817 broadcast(
4818 None,
4819 room.participants
4820 .iter()
4821 .filter_map(|participant| Some(participant.peer_id?.into())),
4822 |peer_id| {
4823 peer.send(
4824 peer_id,
4825 proto::RoomUpdated {
4826 room: Some(room.clone()),
4827 },
4828 )
4829 },
4830 );
4831}
4832
4833fn channel_updated(
4834 channel: &db::channel::Model,
4835 room: &proto::Room,
4836 peer: &Peer,
4837 pool: &ConnectionPool,
4838) {
4839 let participants = room
4840 .participants
4841 .iter()
4842 .map(|p| p.user_id)
4843 .collect::<Vec<_>>();
4844
4845 broadcast(
4846 None,
4847 pool.channel_connection_ids(channel.root_id())
4848 .filter_map(|(channel_id, role)| {
4849 role.can_see_channel(channel.visibility).then(|| channel_id)
4850 }),
4851 |peer_id| {
4852 peer.send(
4853 peer_id,
4854 proto::UpdateChannels {
4855 channel_participants: vec![proto::ChannelParticipants {
4856 channel_id: channel.id.to_proto(),
4857 participant_user_ids: participants.clone(),
4858 }],
4859 ..Default::default()
4860 },
4861 )
4862 },
4863 );
4864}
4865
4866async fn send_remote_projects_update(
4867 user_id: UserId,
4868 mut status: proto::RemoteProjectsUpdate,
4869 session: &Session,
4870) {
4871 let pool = session.connection_pool().await;
4872 for dev_server in &mut status.dev_servers {
4873 dev_server.status =
4874 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4875 }
4876 let connections = pool.user_connection_ids(user_id);
4877 for connection_id in connections {
4878 session.peer.send(connection_id, status.clone()).trace_err();
4879 }
4880}
4881
4882async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4883 let db = session.db().await;
4884
4885 let contacts = db.get_contacts(user_id).await?;
4886 let busy = db.is_user_busy(user_id).await?;
4887
4888 let pool = session.connection_pool().await;
4889 let updated_contact = contact_for_user(user_id, busy, &pool);
4890 for contact in contacts {
4891 if let db::Contact::Accepted {
4892 user_id: contact_user_id,
4893 ..
4894 } = contact
4895 {
4896 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4897 session
4898 .peer
4899 .send(
4900 contact_conn_id,
4901 proto::UpdateContacts {
4902 contacts: vec![updated_contact.clone()],
4903 remove_contacts: Default::default(),
4904 incoming_requests: Default::default(),
4905 remove_incoming_requests: Default::default(),
4906 outgoing_requests: Default::default(),
4907 remove_outgoing_requests: Default::default(),
4908 },
4909 )
4910 .trace_err();
4911 }
4912 }
4913 }
4914 Ok(())
4915}
4916
4917async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
4918 log::info!("lost dev server connection, unsharing projects");
4919 let project_ids = session
4920 .db()
4921 .await
4922 .get_stale_dev_server_projects(session.connection_id)
4923 .await?;
4924
4925 for project_id in project_ids {
4926 // not unshare re-checks the connection ids match, so we get away with no transaction
4927 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
4928 }
4929
4930 let user_id = session.dev_server().user_id;
4931 let update = session.db().await.remote_projects_update(user_id).await?;
4932
4933 send_remote_projects_update(user_id, update, session).await;
4934
4935 Ok(())
4936}
4937
/// Removes `connection_id` from its room, if it is in one, and propagates the consequences.
///
/// The steps are:
/// - notify the other connections of projects the user left (`project_left`);
/// - broadcast the updated room;
/// - broadcast channel participants when the database returns an associated channel;
/// - send `CallCanceled` to users whose pending calls were canceled;
/// - refresh the contact entries of this user and those users;
/// - remove the user from the LiveKit room, deleting it when the database marked the room
///   as deleted.
///
/// Returns `Ok(())` without doing anything when the database reports no room was left.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values remain usable
    // after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so `room` keeps the default value here.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which acquires the
    // connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5011
5012async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5013 let left_channel_buffers = session
5014 .db()
5015 .await
5016 .leave_channel_buffers(session.connection_id)
5017 .await?;
5018
5019 for left_buffer in left_channel_buffers {
5020 channel_buffer_updated(
5021 session.connection_id,
5022 left_buffer.connections,
5023 &proto::UpdateChannelBufferCollaborators {
5024 channel_id: left_buffer.channel_id.to_proto(),
5025 collaborators: left_buffer.collaborators,
5026 },
5027 &session.peer,
5028 );
5029 }
5030
5031 Ok(())
5032}
5033
5034fn project_left(project: &db::LeftProject, session: &UserSession) {
5035 for connection_id in &project.connection_ids {
5036 if project.should_unshare {
5037 session
5038 .peer
5039 .send(
5040 *connection_id,
5041 proto::UnshareProject {
5042 project_id: project.id.to_proto(),
5043 },
5044 )
5045 .trace_err();
5046 } else {
5047 session
5048 .peer
5049 .send(
5050 *connection_id,
5051 proto::RemoveProjectCollaborator {
5052 project_id: project.id.to_proto(),
5053 peer_id: Some(session.connection_id.into()),
5054 },
5055 )
5056 .trace_err();
5057 }
5058 }
5059}
5060
5061pub trait ResultExt {
5062 type Ok;
5063
5064 fn trace_err(self) -> Option<Self::Ok>;
5065}
5066
5067impl<T, E> ResultExt for Result<T, E>
5068where
5069 E: std::fmt::Debug,
5070{
5071 type Ok = T;
5072
5073 #[track_caller]
5074 fn trace_err(self) -> Option<T> {
5075 match self {
5076 Ok(value) => Some(value),
5077 Err(error) => {
5078 tracing::error!("{:?}", error);
5079 None
5080 }
5081 }
5082 }
5083}