1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
6use crate::{
7 AppState, Error, Result, auth,
8 db::{
9 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
10 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
11 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
12 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15};
16use anyhow::{Context as _, anyhow, bail};
17use async_tungstenite::tungstenite::{
18 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
19};
20use axum::{
21 Extension, Router, TypedHeader,
22 body::Body,
23 extract::{
24 ConnectInfo, WebSocketUpgrade,
25 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{MutexGuard, Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// Grace period associated with client reconnection (its consumers live outside this
/// view — presumably the disconnect/rejoin handling).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long a freshly started server waits before reclaiming resources owned by stale
/// servers. Kubernetes gives terminated pods 10s to shut down gracefully; waiting a bit
/// longer ensures the old pods are gone before their rooms and buffers are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for notification listings (presumably; used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on a message's length (presumably channel chat messages; used outside this view).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for channel message listings (presumably; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
86
87type MessageHandler =
88 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone, Debug)]
105pub enum Principal {
106 User(User),
107 Impersonated { user: User, admin: User },
108}
109
110impl Principal {
111 fn update_span(&self, span: &tracing::Span) {
112 match &self {
113 Principal::User(user) => {
114 span.record("user_id", user.id.0);
115 span.record("login", &user.github_login);
116 }
117 Principal::Impersonated { user, admin } => {
118 span.record("user_id", user.id.0);
119 span.record("login", &user.github_login);
120 span.record("impersonator", &admin.github_login);
121 }
122 }
123 }
124}
125
126#[derive(Clone)]
127struct Session {
128 principal: Principal,
129 connection_id: ConnectionId,
130 db: Arc<tokio::sync::Mutex<DbHandle>>,
131 peer: Arc<Peer>,
132 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
133 app_state: Arc<AppState>,
134 supermaven_client: Option<Arc<SupermavenAdminApi>>,
135 /// The GeoIP country code for the user.
136 #[allow(unused)]
137 geoip_country_code: Option<String>,
138 system_id: Option<String>,
139 _executor: Executor,
140}
141
142impl Session {
143 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
144 #[cfg(test)]
145 tokio::task::yield_now().await;
146 let guard = self.db.lock().await;
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 guard
150 }
151
152 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.connection_pool.lock();
156 ConnectionPoolGuard {
157 guard,
158 _not_send: PhantomData,
159 }
160 }
161
162 fn is_staff(&self) -> bool {
163 match &self.principal {
164 Principal::User(user) => user.admin,
165 Principal::Impersonated { .. } => true,
166 }
167 }
168
169 pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
170 if self.is_staff() {
171 return Ok(proto::Plan::ZedPro);
172 }
173
174 let user_id = self.user_id();
175
176 let subscription = db.get_active_billing_subscription(user_id).await?;
177 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
178
179 let plan = if let Some(subscription_kind) = subscription_kind {
180 match subscription_kind {
181 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
182 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
183 SubscriptionKind::ZedFree => proto::Plan::Free,
184 }
185 } else {
186 proto::Plan::Free
187 };
188
189 Ok(plan)
190 }
191
192 fn user_id(&self) -> UserId {
193 match &self.principal {
194 Principal::User(user) => user.id,
195 Principal::Impersonated { user, .. } => user.id,
196 }
197 }
198
199 pub fn email(&self) -> Option<String> {
200 match &self.principal {
201 Principal::User(user) => user.email_address.clone(),
202 Principal::Impersonated { user, .. } => user.email_address.clone(),
203 }
204 }
205}
206
207impl Debug for Session {
208 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
209 let mut result = f.debug_struct("Session");
210 match &self.principal {
211 Principal::User(user) => {
212 result.field("user", &user.github_login);
213 }
214 Principal::Impersonated { user, admin } => {
215 result.field("user", &user.github_login);
216 result.field("impersonator", &admin.github_login);
217 }
218 }
219 result.field("connection_id", &self.connection_id).finish()
220 }
221}
222
223struct DbHandle(Arc<Database>);
224
225impl Deref for DbHandle {
226 type Target = Database;
227
228 fn deref(&self) -> &Self::Target {
229 self.0.as_ref()
230 }
231}
232
/// The collaboration RPC server: owns the peer, the pool of live connections, and the
/// table of message handlers.
pub struct Server {
    /// This server's id; behind a mutex so tests can swap it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` on teardown; connection loops watch it and exit.
    teardown: watch::Sender<bool>,
}
241
/// Lock guard over the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes this guard `!Send` as well — presumably so it cannot be
    // held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
246
/// Serializable view of the server's peer and connection pool; holds the pool lock for
/// as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
253
254pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
255where
256 S: Serializer,
257 T: Deref<Target = U>,
258 U: Serialize,
259{
260 Serialize::serialize(value.deref(), serializer)
261}
262
263impl Server {
264 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
265 let mut server = Self {
266 id: parking_lot::Mutex::new(id),
267 peer: Peer::new(id.0 as u32),
268 app_state: app_state.clone(),
269 connection_pool: Default::default(),
270 handlers: Default::default(),
271 teardown: watch::channel(false).0,
272 };
273
274 server
275 .add_request_handler(ping)
276 .add_request_handler(create_room)
277 .add_request_handler(join_room)
278 .add_request_handler(rejoin_room)
279 .add_request_handler(leave_room)
280 .add_request_handler(set_room_participant_role)
281 .add_request_handler(call)
282 .add_request_handler(cancel_call)
283 .add_message_handler(decline_call)
284 .add_request_handler(update_participant_location)
285 .add_request_handler(share_project)
286 .add_message_handler(unshare_project)
287 .add_request_handler(join_project)
288 .add_message_handler(leave_project)
289 .add_request_handler(update_project)
290 .add_request_handler(update_worktree)
291 .add_request_handler(update_repository)
292 .add_request_handler(remove_repository)
293 .add_message_handler(start_language_server)
294 .add_message_handler(update_language_server)
295 .add_message_handler(update_diagnostic_summary)
296 .add_message_handler(update_worktree_settings)
297 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
298 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
301 .add_request_handler(forward_find_search_candidates_request)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
311 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
312 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
314 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
315 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
316 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
317 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
318 .add_request_handler(
319 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
320 )
321 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
322 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
323 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
324 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
325 .add_request_handler(
326 forward_read_only_project_request::<proto::LanguageServerIdForName>,
327 )
328 .add_request_handler(
329 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
330 )
331 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
332 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
333 .add_request_handler(
334 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
335 )
336 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
341 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
342 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
343 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
344 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
345 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
346 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
347 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
350 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
351 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
352 .add_request_handler(
353 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
354 )
355 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
356 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
357 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
358 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
359 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
360 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
361 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
362 .add_message_handler(create_buffer_for_peer)
363 .add_request_handler(update_buffer)
364 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
366 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
367 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
368 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
369 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
370 .add_request_handler(get_users)
371 .add_request_handler(fuzzy_search_users)
372 .add_request_handler(request_contact)
373 .add_request_handler(remove_contact)
374 .add_request_handler(respond_to_contact_request)
375 .add_message_handler(subscribe_to_channels)
376 .add_request_handler(create_channel)
377 .add_request_handler(delete_channel)
378 .add_request_handler(invite_channel_member)
379 .add_request_handler(remove_channel_member)
380 .add_request_handler(set_channel_member_role)
381 .add_request_handler(set_channel_visibility)
382 .add_request_handler(rename_channel)
383 .add_request_handler(join_channel_buffer)
384 .add_request_handler(leave_channel_buffer)
385 .add_message_handler(update_channel_buffer)
386 .add_request_handler(rejoin_channel_buffers)
387 .add_request_handler(get_channel_members)
388 .add_request_handler(respond_to_channel_invite)
389 .add_request_handler(join_channel)
390 .add_request_handler(join_channel_chat)
391 .add_message_handler(leave_channel_chat)
392 .add_request_handler(send_channel_message)
393 .add_request_handler(remove_channel_message)
394 .add_request_handler(update_channel_message)
395 .add_request_handler(get_channel_messages)
396 .add_request_handler(get_channel_messages_by_id)
397 .add_request_handler(get_notifications)
398 .add_request_handler(mark_notification_as_read)
399 .add_request_handler(move_channel)
400 .add_request_handler(follow)
401 .add_message_handler(unfollow)
402 .add_message_handler(update_followers)
403 .add_request_handler(get_private_user_info)
404 .add_request_handler(get_llm_api_token)
405 .add_request_handler(accept_terms_of_service)
406 .add_message_handler(acknowledge_channel_message)
407 .add_message_handler(acknowledge_buffer_version)
408 .add_request_handler(get_supermaven_api_key)
409 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
410 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
411 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
412 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
413 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
414 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
415 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
416 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
417 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
418 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
419 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
420 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
421 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
422 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
423 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
424 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
425 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
426 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
427 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
428 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
429 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
430 .add_message_handler(update_context);
431
432 Arc::new(server)
433 }
434
435 pub async fn start(&self) -> Result<()> {
436 let server_id = *self.id.lock();
437 let app_state = self.app_state.clone();
438 let peer = self.peer.clone();
439 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
440 let pool = self.connection_pool.clone();
441 let livekit_client = self.app_state.livekit_client.clone();
442
443 let span = info_span!("start server");
444 self.app_state.executor.spawn_detached(
445 async move {
446 tracing::info!("waiting for cleanup timeout");
447 timeout.await;
448 tracing::info!("cleanup timeout expired, retrieving stale rooms");
449 if let Some((room_ids, channel_ids)) = app_state
450 .db
451 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
452 .await
453 .trace_err()
454 {
455 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
456 tracing::info!(
457 stale_channel_buffer_count = channel_ids.len(),
458 "retrieved stale channel buffers"
459 );
460
461 for channel_id in channel_ids {
462 if let Some(refreshed_channel_buffer) = app_state
463 .db
464 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
465 .await
466 .trace_err()
467 {
468 for connection_id in refreshed_channel_buffer.connection_ids {
469 peer.send(
470 connection_id,
471 proto::UpdateChannelBufferCollaborators {
472 channel_id: channel_id.to_proto(),
473 collaborators: refreshed_channel_buffer
474 .collaborators
475 .clone(),
476 },
477 )
478 .trace_err();
479 }
480 }
481 }
482
483 for room_id in room_ids {
484 let mut contacts_to_update = HashSet::default();
485 let mut canceled_calls_to_user_ids = Vec::new();
486 let mut livekit_room = String::new();
487 let mut delete_livekit_room = false;
488
489 if let Some(mut refreshed_room) = app_state
490 .db
491 .clear_stale_room_participants(room_id, server_id)
492 .await
493 .trace_err()
494 {
495 tracing::info!(
496 room_id = room_id.0,
497 new_participant_count = refreshed_room.room.participants.len(),
498 "refreshed room"
499 );
500 room_updated(&refreshed_room.room, &peer);
501 if let Some(channel) = refreshed_room.channel.as_ref() {
502 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
503 }
504 contacts_to_update
505 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
506 contacts_to_update
507 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
508 canceled_calls_to_user_ids =
509 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
510 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
511 delete_livekit_room = refreshed_room.room.participants.is_empty();
512 }
513
514 {
515 let pool = pool.lock();
516 for canceled_user_id in canceled_calls_to_user_ids {
517 for connection_id in pool.user_connection_ids(canceled_user_id) {
518 peer.send(
519 connection_id,
520 proto::CallCanceled {
521 room_id: room_id.to_proto(),
522 },
523 )
524 .trace_err();
525 }
526 }
527 }
528
529 for user_id in contacts_to_update {
530 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
531 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
532 if let Some((busy, contacts)) = busy.zip(contacts) {
533 let pool = pool.lock();
534 let updated_contact = contact_for_user(user_id, busy, &pool);
535 for contact in contacts {
536 if let db::Contact::Accepted {
537 user_id: contact_user_id,
538 ..
539 } = contact
540 {
541 for contact_conn_id in
542 pool.user_connection_ids(contact_user_id)
543 {
544 peer.send(
545 contact_conn_id,
546 proto::UpdateContacts {
547 contacts: vec![updated_contact.clone()],
548 remove_contacts: Default::default(),
549 incoming_requests: Default::default(),
550 remove_incoming_requests: Default::default(),
551 outgoing_requests: Default::default(),
552 remove_outgoing_requests: Default::default(),
553 },
554 )
555 .trace_err();
556 }
557 }
558 }
559 }
560 }
561
562 if let Some(live_kit) = livekit_client.as_ref() {
563 if delete_livekit_room {
564 live_kit.delete_room(livekit_room).await.trace_err();
565 }
566 }
567 }
568 }
569
570 app_state
571 .db
572 .delete_stale_servers(&app_state.config.zed_environment, server_id)
573 .await
574 .trace_err();
575 }
576 .instrument(span),
577 );
578 Ok(())
579 }
580
581 pub fn teardown(&self) {
582 self.peer.teardown();
583 self.connection_pool.lock().reset();
584 let _ = self.teardown.send(true);
585 }
586
587 #[cfg(test)]
588 pub fn reset(&self, id: ServerId) {
589 self.teardown();
590 *self.id.lock() = id;
591 self.peer.reset(id.0 as u32);
592 let _ = self.teardown.send(false);
593 }
594
595 #[cfg(test)]
596 pub fn id(&self) -> ServerId {
597 *self.id.lock()
598 }
599
600 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
601 where
602 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
603 Fut: 'static + Send + Future<Output = Result<()>>,
604 M: EnvelopedMessage,
605 {
606 let prev_handler = self.handlers.insert(
607 TypeId::of::<M>(),
608 Box::new(move |envelope, session| {
609 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
610 let received_at = envelope.received_at;
611 tracing::info!("message received");
612 let start_time = Instant::now();
613 let future = (handler)(*envelope, session);
614 async move {
615 let result = future.await;
616 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
617 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
618 let queue_duration_ms = total_duration_ms - processing_duration_ms;
619 let payload_type = M::NAME;
620
621 match result {
622 Err(error) => {
623 tracing::error!(
624 ?error,
625 total_duration_ms,
626 processing_duration_ms,
627 queue_duration_ms,
628 payload_type,
629 "error handling message"
630 )
631 }
632 Ok(()) => tracing::info!(
633 total_duration_ms,
634 processing_duration_ms,
635 queue_duration_ms,
636 "finished handling message"
637 ),
638 }
639 }
640 .boxed()
641 }),
642 );
643 if prev_handler.is_some() {
644 panic!("registered a handler for the same message twice");
645 }
646 self
647 }
648
649 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
650 where
651 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
652 Fut: 'static + Send + Future<Output = Result<()>>,
653 M: EnvelopedMessage,
654 {
655 self.add_handler(move |envelope, session| handler(envelope.payload, session));
656 self
657 }
658
659 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
660 where
661 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
662 Fut: Send + Future<Output = Result<()>>,
663 M: RequestMessage,
664 {
665 let handler = Arc::new(handler);
666 self.add_handler(move |envelope, session| {
667 let receipt = envelope.receipt();
668 let handler = handler.clone();
669 async move {
670 let peer = session.peer.clone();
671 let responded = Arc::new(AtomicBool::default());
672 let response = Response {
673 peer: peer.clone(),
674 responded: responded.clone(),
675 receipt,
676 };
677 match (handler)(envelope.payload, response, session).await {
678 Ok(()) => {
679 if responded.load(std::sync::atomic::Ordering::SeqCst) {
680 Ok(())
681 } else {
682 Err(anyhow!("handler did not send a response"))?
683 }
684 }
685 Err(error) => {
686 let proto_err = match &error {
687 Error::Internal(err) => err.to_proto(),
688 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
689 };
690 peer.respond_with_error(receipt, proto_err)?;
691 Err(error)
692 }
693 }
694 }
695 })
696 }
697
698 pub fn handle_connection(
699 self: &Arc<Self>,
700 connection: Connection,
701 address: String,
702 principal: Principal,
703 zed_version: ZedVersion,
704 geoip_country_code: Option<String>,
705 system_id: Option<String>,
706 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
707 executor: Executor,
708 ) -> impl Future<Output = ()> + use<> {
709 let this = self.clone();
710 let span = info_span!("handle connection", %address,
711 connection_id=field::Empty,
712 user_id=field::Empty,
713 login=field::Empty,
714 impersonator=field::Empty,
715 geoip_country_code=field::Empty
716 );
717 principal.update_span(&span);
718 if let Some(country_code) = geoip_country_code.as_ref() {
719 span.record("geoip_country_code", country_code);
720 }
721
722 let mut teardown = self.teardown.subscribe();
723 async move {
724 if *teardown.borrow() {
725 tracing::error!("server is tearing down");
726 return
727 }
728 let (connection_id, handle_io, mut incoming_rx) = this
729 .peer
730 .add_connection(connection, {
731 let executor = executor.clone();
732 move |duration| executor.sleep(duration)
733 });
734 tracing::Span::current().record("connection_id", format!("{}", connection_id));
735
736 tracing::info!("connection opened");
737
738 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
739 let http_client = match ReqwestClient::user_agent(&user_agent) {
740 Ok(http_client) => Arc::new(http_client),
741 Err(error) => {
742 tracing::error!(?error, "failed to create HTTP client");
743 return;
744 }
745 };
746
747 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
748 supermaven_admin_api_key.to_string(),
749 http_client.clone(),
750 )));
751
752 let session = Session {
753 principal: principal.clone(),
754 connection_id,
755 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
756 peer: this.peer.clone(),
757 connection_pool: this.connection_pool.clone(),
758 app_state: this.app_state.clone(),
759 geoip_country_code,
760 system_id,
761 _executor: executor.clone(),
762 supermaven_client,
763 };
764
765 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
766 tracing::error!(?error, "failed to send initial client update");
767 return;
768 }
769
770 let handle_io = handle_io.fuse();
771 futures::pin_mut!(handle_io);
772
773 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
774 // This prevents deadlocks when e.g., client A performs a request to client B and
775 // client B performs a request to client A. If both clients stop processing further
776 // messages until their respective request completes, they won't have a chance to
777 // respond to the other client's request and cause a deadlock.
778 //
779 // This arrangement ensures we will attempt to process earlier messages first, but fall
780 // back to processing messages arrived later in the spirit of making progress.
781 let mut foreground_message_handlers = FuturesUnordered::new();
782 let concurrent_handlers = Arc::new(Semaphore::new(256));
783 loop {
784 let next_message = async {
785 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
786 let message = incoming_rx.next().await;
787 (permit, message)
788 }.fuse();
789 futures::pin_mut!(next_message);
790 futures::select_biased! {
791 _ = teardown.changed().fuse() => return,
792 result = handle_io => {
793 if let Err(error) = result {
794 tracing::error!(?error, "error handling I/O");
795 }
796 break;
797 }
798 _ = foreground_message_handlers.next() => {}
799 next_message = next_message => {
800 let (permit, message) = next_message;
801 if let Some(message) = message {
802 let type_name = message.payload_type_name();
803 // note: we copy all the fields from the parent span so we can query them in the logs.
804 // (https://github.com/tokio-rs/tracing/issues/2670).
805 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
806 user_id=field::Empty,
807 login=field::Empty,
808 impersonator=field::Empty,
809 );
810 principal.update_span(&span);
811 let span_enter = span.enter();
812 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
813 let is_background = message.is_background();
814 let handle_message = (handler)(message, session.clone());
815 drop(span_enter);
816
817 let handle_message = async move {
818 handle_message.await;
819 drop(permit);
820 }.instrument(span);
821 if is_background {
822 executor.spawn_detached(handle_message);
823 } else {
824 foreground_message_handlers.push(handle_message);
825 }
826 } else {
827 tracing::error!("no message handler");
828 }
829 } else {
830 tracing::info!("connection closed");
831 break;
832 }
833 }
834 }
835 }
836
837 drop(foreground_message_handlers);
838 tracing::info!("signing out");
839 if let Err(error) = connection_lost(session, teardown, executor).await {
840 tracing::error!(?error, "error signing out");
841 }
842
843 }.instrument(span)
844 }
845
846 async fn send_initial_client_update(
847 &self,
848 connection_id: ConnectionId,
849 principal: &Principal,
850 zed_version: ZedVersion,
851 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
852 session: &Session,
853 ) -> Result<()> {
854 self.peer.send(
855 connection_id,
856 proto::Hello {
857 peer_id: Some(connection_id.into()),
858 },
859 )?;
860 tracing::info!("sent hello message");
861 if let Some(send_connection_id) = send_connection_id.take() {
862 let _ = send_connection_id.send(connection_id);
863 }
864
865 match principal {
866 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
867 if !user.connected_once {
868 self.peer.send(connection_id, proto::ShowContacts {})?;
869 self.app_state
870 .db
871 .set_user_connected_once(user.id, true)
872 .await?;
873 }
874
875 update_user_plan(user.id, session).await?;
876
877 let contacts = self.app_state.db.get_contacts(user.id).await?;
878
879 {
880 let mut pool = self.connection_pool.lock();
881 pool.add_connection(connection_id, user.id, user.admin, zed_version);
882 self.peer.send(
883 connection_id,
884 build_initial_contacts_update(contacts, &pool),
885 )?;
886 }
887
888 if should_auto_subscribe_to_channels(zed_version) {
889 subscribe_user_to_channels(user.id, session).await?;
890 }
891
892 if let Some(incoming_call) =
893 self.app_state.db.incoming_call_for_user(user.id).await?
894 {
895 self.peer.send(connection_id, incoming_call)?;
896 }
897
898 update_user_contacts(user.id, session).await?;
899 }
900 }
901
902 Ok(())
903 }
904
905 pub async fn invite_code_redeemed(
906 self: &Arc<Self>,
907 inviter_id: UserId,
908 invitee_id: UserId,
909 ) -> Result<()> {
910 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
911 if let Some(code) = &user.invite_code {
912 let pool = self.connection_pool.lock();
913 let invitee_contact = contact_for_user(invitee_id, false, &pool);
914 for connection_id in pool.user_connection_ids(inviter_id) {
915 self.peer.send(
916 connection_id,
917 proto::UpdateContacts {
918 contacts: vec![invitee_contact.clone()],
919 ..Default::default()
920 },
921 )?;
922 self.peer.send(
923 connection_id,
924 proto::UpdateInviteInfo {
925 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
926 count: user.invite_count as u32,
927 },
928 )?;
929 }
930 }
931 }
932 Ok(())
933 }
934
935 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
936 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
937 if let Some(invite_code) = &user.invite_code {
938 let pool = self.connection_pool.lock();
939 for connection_id in pool.user_connection_ids(user_id) {
940 self.peer.send(
941 connection_id,
942 proto::UpdateInviteInfo {
943 url: format!(
944 "{}{}",
945 self.app_state.config.invite_link_prefix, invite_code
946 ),
947 count: user.invite_count as u32,
948 },
949 )?;
950 }
951 }
952 }
953 Ok(())
954 }
955
956 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
957 let pool = self.connection_pool.lock();
958 for connection_id in pool.user_connection_ids(user_id) {
959 self.peer
960 .send(connection_id, proto::RefreshLlmToken {})
961 .trace_err();
962 }
963 }
964
965 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
966 ServerSnapshot {
967 connection_pool: ConnectionPoolGuard {
968 guard: self.connection_pool.lock(),
969 _not_send: PhantomData,
970 },
971 peer: &self.peer,
972 }
973 }
974}
975
976impl Deref for ConnectionPoolGuard<'_> {
977 type Target = ConnectionPool;
978
979 fn deref(&self) -> &Self::Target {
980 &self.guard
981 }
982}
983
984impl DerefMut for ConnectionPoolGuard<'_> {
985 fn deref_mut(&mut self) -> &mut Self::Target {
986 &mut self.guard
987 }
988}
989
/// In test builds, calls `check_invariants` every time a guard is dropped, so a
/// violation is reported right after the code that held the lock. In non-test
/// builds the statement is compiled out and dropping the guard only releases
/// the lock.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
996
997fn broadcast<F>(
998 sender_id: Option<ConnectionId>,
999 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1000 mut f: F,
1001) where
1002 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1003{
1004 for receiver_id in receiver_ids {
1005 if Some(receiver_id) != sender_id {
1006 if let Err(error) = f(receiver_id) {
1007 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1008 }
1009 }
1010 }
1011}
1012
1013pub struct ProtocolVersion(u32);
1014
1015impl Header for ProtocolVersion {
1016 fn name() -> &'static HeaderName {
1017 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1018 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1019 }
1020
1021 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1022 where
1023 Self: Sized,
1024 I: Iterator<Item = &'i axum::http::HeaderValue>,
1025 {
1026 let version = values
1027 .next()
1028 .ok_or_else(axum::headers::Error::invalid)?
1029 .to_str()
1030 .map_err(|_| axum::headers::Error::invalid())?
1031 .parse()
1032 .map_err(|_| axum::headers::Error::invalid())?;
1033 Ok(Self(version))
1034 }
1035
1036 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1037 values.extend([self.0.to_string().parse().unwrap()]);
1038 }
1039}
1040
1041pub struct AppVersionHeader(SemanticVersion);
1042impl Header for AppVersionHeader {
1043 fn name() -> &'static HeaderName {
1044 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1045 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1046 }
1047
1048 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1049 where
1050 Self: Sized,
1051 I: Iterator<Item = &'i axum::http::HeaderValue>,
1052 {
1053 let version = values
1054 .next()
1055 .ok_or_else(axum::headers::Error::invalid)?
1056 .to_str()
1057 .map_err(|_| axum::headers::Error::invalid())?
1058 .parse()
1059 .map_err(|_| axum::headers::Error::invalid())?;
1060 Ok(Self(version))
1061 }
1062
1063 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1064 values.extend([self.0.to_string().parse().unwrap()]);
1065 }
1066}
1067
/// Builds the collab RPC router.
///
/// `Router::layer` only wraps routes registered before it is called, so the
/// ordering here is significant:
/// - `/rpc` (the websocket endpoint) is wrapped by the `ServiceBuilder` layer,
///   which provides the `AppState` extension and runs `auth::validate_header`;
/// - `/metrics` is registered afterwards and is therefore NOT behind that
///   middleware;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1079
1080pub async fn handle_websocket_request(
1081 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1082 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1083 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1084 Extension(server): Extension<Arc<Server>>,
1085 Extension(principal): Extension<Principal>,
1086 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1087 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1088 ws: WebSocketUpgrade,
1089) -> axum::response::Response {
1090 if protocol_version != rpc::PROTOCOL_VERSION {
1091 return (
1092 StatusCode::UPGRADE_REQUIRED,
1093 "client must be upgraded".to_string(),
1094 )
1095 .into_response();
1096 }
1097
1098 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1099 return (
1100 StatusCode::UPGRADE_REQUIRED,
1101 "no version header found".to_string(),
1102 )
1103 .into_response();
1104 };
1105
1106 if !version.can_collaborate() {
1107 return (
1108 StatusCode::UPGRADE_REQUIRED,
1109 "client must be upgraded".to_string(),
1110 )
1111 .into_response();
1112 }
1113
1114 let socket_address = socket_address.to_string();
1115 ws.on_upgrade(move |socket| {
1116 let socket = socket
1117 .map_ok(to_tungstenite_message)
1118 .err_into()
1119 .with(|message| async move { to_axum_message(message) });
1120 let connection = Connection::new(Box::pin(socket));
1121 async move {
1122 server
1123 .handle_connection(
1124 connection,
1125 socket_address,
1126 principal,
1127 version,
1128 country_code_header.map(|header| header.to_string()),
1129 system_id_header.map(|header| header.to_string()),
1130 None,
1131 Executor::Production,
1132 )
1133 .await;
1134 }
1135 })
1136}
1137
1138pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1139 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1140 let connections_metric = CONNECTIONS_METRIC
1141 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1142
1143 let connections = server
1144 .connection_pool
1145 .lock()
1146 .connections()
1147 .filter(|connection| !connection.admin)
1148 .count();
1149 connections_metric.set(connections as _);
1150
1151 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1153 register_int_gauge!(
1154 "shared_projects",
1155 "number of open projects with one or more guests"
1156 )
1157 .unwrap()
1158 });
1159
1160 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1161 shared_projects_metric.set(shared_projects as _);
1162
1163 let encoder = prometheus::TextEncoder::new();
1164 let metric_families = prometheus::gather();
1165 let encoded_metrics = encoder
1166 .encode_to_string(&metric_families)
1167 .map_err(|err| anyhow!("{}", err))?;
1168 Ok(encoded_metrics)
1169}
1170
/// Cleans up after a connection drops.
///
/// Immediate cleanup:
/// - the peer is disconnected;
/// - the connection is removed from the connection pool;
/// - the loss is recorded in the database. Database errors here are only logged.
///
/// The function then waits until either `RECONNECT_TIMEOUT` elapses or
/// `teardown.changed()` resolves. `select_biased!` polls the branches in the
/// order written, so the timeout is checked first.
/// - On timeout, the client did not reconnect in time:
///   - the session leaves its room and its channel buffers (errors logged);
///   - if the user has no other online connection, `decline_call(None, ..)`
///     is issued and any resulting room is passed to `room_updated`;
///   - finally the user's contacts are updated.
/// - On teardown, nothing further is done.
///
/// # Errors
/// Fails if the pool cannot remove the connection or if `update_user_contacts`
/// fails; the other cleanup steps only log their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls once the user has no remaining live connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1217
1218/// Acknowledges a ping from a client, used to keep the connection alive.
1219async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1220 response.send(proto::Ack {})?;
1221 Ok(())
1222}
1223
1224/// Creates a new room for calling (outside of channels)
1225async fn create_room(
1226 _request: proto::CreateRoom,
1227 response: Response<proto::CreateRoom>,
1228 session: Session,
1229) -> Result<()> {
1230 let livekit_room = nanoid::nanoid!(30);
1231
1232 let live_kit_connection_info = util::maybe!(async {
1233 let live_kit = session.app_state.livekit_client.as_ref();
1234 let live_kit = live_kit?;
1235 let user_id = session.user_id().to_string();
1236
1237 let token = live_kit
1238 .room_token(&livekit_room, &user_id.to_string())
1239 .trace_err()?;
1240
1241 Some(proto::LiveKitConnectionInfo {
1242 server_url: live_kit.url().into(),
1243 token,
1244 can_publish: true,
1245 })
1246 })
1247 .await;
1248
1249 let room = session
1250 .db()
1251 .await
1252 .create_room(session.user_id(), session.connection_id, &livekit_room)
1253 .await?;
1254
1255 response.send(proto::CreateRoomResponse {
1256 room: Some(room.clone()),
1257 live_kit_connection_info,
1258 })?;
1259
1260 update_user_contacts(session.user_id(), &session).await?;
1261 Ok(())
1262}
1263
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// Rooms backed by a channel are delegated to `join_channel_internal`. For other
/// rooms, the steps are:
/// 1. The caller joins the room through the database, and `room_updated` is
///    called with the result.
/// 2. Every connection of the joining user is sent `CallCanceled` for this room.
/// 3. When a LiveKit client is configured, a publish-capable token is minted.
///    Token errors are logged and simply yield no connection info.
/// 4. The response is sent, and then the user's contacts are updated.
///
/// # Errors
/// Propagates database, send and contact-update failures.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so the database guard is consumed (`into_inner`) before the
    // connection pool is locked below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Tell all of this user's connections the call for this room is over —
    // presumably so other devices stop showing the incoming call (confirm with client).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1330
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the database result is held (inner block), the handler:
/// 1. responds with the room, the reshared projects (with collaborators) and the
///    rejoined projects;
/// 2. calls `room_updated` for the room;
/// 3. for every reshared project, sends each collaborator an
///    `UpdateProjectCollaborator` (old connection id -> this connection), and
///    forwards an `UpdateProject` with the project's worktrees to every
///    collaborator except this connection;
/// 4. resends rejoined projects' state through `notify_rejoined_projects`.
///
/// After the database result is released, the channel (if the room has one) is
/// passed to `channel_updated`, and the user's contacts are updated.
///
/// # Errors
/// Propagates database, response, rejoin-notification and contact-update failures.
/// Per-collaborator send failures are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-point every collaborator from the old connection to this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1422
/// Announces a rejoining connection to its projects' collaborators, then resends
/// the projects' state to that connection.
///
/// First, every collaborator of every rejoined project is sent an
/// `UpdateProjectCollaborator` that maps the project's old connection id to the
/// session's connection. Failures here are only logged.
///
/// Then, per project, the worktrees, updated repositories and removed repositories
/// are drained out of `rejoined_projects` (via `mem::take`) and streamed to the
/// session's own connection in this order:
/// - for each worktree: split `UpdateWorktree` messages, then its diagnostic
///   summaries, then its settings files;
/// - split repository updates;
/// - `RemoveRepository` messages.
///
/// # Errors
/// Returns the first failed send to the rejoining connection; messages after it
/// are not sent.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1507
1508/// leave room disconnects from the room.
1509async fn leave_room(
1510 _: proto::LeaveRoom,
1511 response: Response<proto::LeaveRoom>,
1512 session: Session,
1513) -> Result<()> {
1514 leave_room_for_session(&session, session.connection_id).await?;
1515 response.send(proto::Ack {})?;
1516 Ok(())
1517}
1518
1519/// Updates the permissions of someone else in the room.
1520async fn set_room_participant_role(
1521 request: proto::SetRoomParticipantRole,
1522 response: Response<proto::SetRoomParticipantRole>,
1523 session: Session,
1524) -> Result<()> {
1525 let user_id = UserId::from_proto(request.user_id);
1526 let role = ChannelRole::from(request.role());
1527
1528 let (livekit_room, can_publish) = {
1529 let room = session
1530 .db()
1531 .await
1532 .set_room_participant_role(
1533 session.user_id(),
1534 RoomId::from_proto(request.room_id),
1535 user_id,
1536 role,
1537 )
1538 .await?;
1539
1540 let livekit_room = room.livekit_room.clone();
1541 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1542 room_updated(&room, &session.peer);
1543 (livekit_room, can_publish)
1544 };
1545
1546 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1547 live_kit
1548 .update_participant(
1549 livekit_room.clone(),
1550 request.user_id.to_string(),
1551 livekit_api::proto::ParticipantPermission {
1552 can_subscribe: true,
1553 can_publish,
1554 can_publish_data: can_publish,
1555 hidden: false,
1556 recorder: false,
1557 },
1558 )
1559 .await
1560 .trace_err();
1561 }
1562
1563 response.send(proto::Ack {})?;
1564 Ok(())
1565}
1566
1567/// Call someone else into the current room
1568async fn call(
1569 request: proto::Call,
1570 response: Response<proto::Call>,
1571 session: Session,
1572) -> Result<()> {
1573 let room_id = RoomId::from_proto(request.room_id);
1574 let calling_user_id = session.user_id();
1575 let calling_connection_id = session.connection_id;
1576 let called_user_id = UserId::from_proto(request.called_user_id);
1577 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1578 if !session
1579 .db()
1580 .await
1581 .has_contact(calling_user_id, called_user_id)
1582 .await?
1583 {
1584 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1585 }
1586
1587 let incoming_call = {
1588 let (room, incoming_call) = &mut *session
1589 .db()
1590 .await
1591 .call(
1592 room_id,
1593 calling_user_id,
1594 calling_connection_id,
1595 called_user_id,
1596 initial_project_id,
1597 )
1598 .await?;
1599 room_updated(room, &session.peer);
1600 mem::take(incoming_call)
1601 };
1602 update_user_contacts(called_user_id, &session).await?;
1603
1604 let mut calls = session
1605 .connection_pool()
1606 .await
1607 .user_connection_ids(called_user_id)
1608 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1609 .collect::<FuturesUnordered<_>>();
1610
1611 while let Some(call_response) = calls.next().await {
1612 match call_response.as_ref() {
1613 Ok(_) => {
1614 response.send(proto::Ack {})?;
1615 return Ok(());
1616 }
1617 Err(_) => {
1618 call_response.trace_err();
1619 }
1620 }
1621 }
1622
1623 {
1624 let room = session
1625 .db()
1626 .await
1627 .call_failed(room_id, called_user_id)
1628 .await?;
1629 room_updated(&room, &session.peer);
1630 }
1631 update_user_contacts(called_user_id, &session).await?;
1632
1633 Err(anyhow!("failed to ring user"))?
1634}
1635
1636/// Cancel an outgoing call.
1637async fn cancel_call(
1638 request: proto::CancelCall,
1639 response: Response<proto::CancelCall>,
1640 session: Session,
1641) -> Result<()> {
1642 let called_user_id = UserId::from_proto(request.called_user_id);
1643 let room_id = RoomId::from_proto(request.room_id);
1644 {
1645 let room = session
1646 .db()
1647 .await
1648 .cancel_call(room_id, session.connection_id, called_user_id)
1649 .await?;
1650 room_updated(&room, &session.peer);
1651 }
1652
1653 for connection_id in session
1654 .connection_pool()
1655 .await
1656 .user_connection_ids(called_user_id)
1657 {
1658 session
1659 .peer
1660 .send(
1661 connection_id,
1662 proto::CallCanceled {
1663 room_id: room_id.to_proto(),
1664 },
1665 )
1666 .trace_err();
1667 }
1668 response.send(proto::Ack {})?;
1669
1670 update_user_contacts(called_user_id, &session).await?;
1671 Ok(())
1672}
1673
1674/// Decline an incoming call.
1675async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1676 let room_id = RoomId::from_proto(message.room_id);
1677 {
1678 let room = session
1679 .db()
1680 .await
1681 .decline_call(Some(room_id), session.user_id())
1682 .await?
1683 .ok_or_else(|| anyhow!("failed to decline call"))?;
1684 room_updated(&room, &session.peer);
1685 }
1686
1687 for connection_id in session
1688 .connection_pool()
1689 .await
1690 .user_connection_ids(session.user_id())
1691 {
1692 session
1693 .peer
1694 .send(
1695 connection_id,
1696 proto::CallCanceled {
1697 room_id: room_id.to_proto(),
1698 },
1699 )
1700 .trace_err();
1701 }
1702 update_user_contacts(session.user_id(), &session).await?;
1703 Ok(())
1704}
1705
1706/// Updates other participants in the room with your current location.
1707async fn update_participant_location(
1708 request: proto::UpdateParticipantLocation,
1709 response: Response<proto::UpdateParticipantLocation>,
1710 session: Session,
1711) -> Result<()> {
1712 let room_id = RoomId::from_proto(request.room_id);
1713 let location = request
1714 .location
1715 .ok_or_else(|| anyhow!("invalid location"))?;
1716
1717 let db = session.db().await;
1718 let room = db
1719 .update_room_participant_location(room_id, session.connection_id, location)
1720 .await?;
1721
1722 room_updated(&room, &session.peer);
1723 response.send(proto::Ack {})?;
1724 Ok(())
1725}
1726
1727/// Share a project into the room.
1728async fn share_project(
1729 request: proto::ShareProject,
1730 response: Response<proto::ShareProject>,
1731 session: Session,
1732) -> Result<()> {
1733 let (project_id, room) = &*session
1734 .db()
1735 .await
1736 .share_project(
1737 RoomId::from_proto(request.room_id),
1738 session.connection_id,
1739 &request.worktrees,
1740 request.is_ssh_project,
1741 )
1742 .await?;
1743 response.send(proto::ShareProjectResponse {
1744 project_id: project_id.to_proto(),
1745 })?;
1746 room_updated(room, &session.peer);
1747
1748 Ok(())
1749}
1750
1751/// Unshare a project from the room.
1752async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1753 let project_id = ProjectId::from_proto(message.project_id);
1754 unshare_project_internal(project_id, session.connection_id, &session).await
1755}
1756
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Inside the scoped block, while the database result is held, the handler:
/// - sends `UnshareProject` to every guest connection reported by the database,
///   except `connection_id` (failures logged by `broadcast`);
/// - calls `room_updated` for the room, if one is returned.
///
/// After that result is dropped, the project is deleted with a separate database
/// call if the database flagged it for deletion.
///
/// # Errors
/// Propagates failures from the unshare and delete database calls.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The guard above has been dropped, so deletion runs in its own database call.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1794
/// Join someone else's shared project.
///
/// Registers this connection in the project through the database, which yields
/// the project and the guest's replica id. The database handle is then released,
/// and `join_project_internal` announces the guest and streams the project state.
///
/// # Errors
/// Propagates the database failure or any error from `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the guest.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1813
/// Something that can deliver a `JoinProjectResponse`.
///
/// This lets `join_project_internal` accept its responder generically. Only the
/// `Response<proto::JoinProject>` implementation is visible in this part of the file.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delivers the join response through the regular RPC `Response`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it resolves to the inherent `Response::send` (inherent
        // items take precedence in path resolution) rather than re-entering this
        // trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1822
1823fn join_project_internal(
1824 response: impl JoinProjectInternalResponse,
1825 session: Session,
1826 project: &mut Project,
1827 replica_id: &ReplicaId,
1828) -> Result<()> {
1829 let collaborators = project
1830 .collaborators
1831 .iter()
1832 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1833 .map(|collaborator| collaborator.to_proto())
1834 .collect::<Vec<_>>();
1835 let project_id = project.id;
1836 let guest_user_id = session.user_id();
1837
1838 let worktrees = project
1839 .worktrees
1840 .iter()
1841 .map(|(id, worktree)| proto::WorktreeMetadata {
1842 id: *id,
1843 root_name: worktree.root_name.clone(),
1844 visible: worktree.visible,
1845 abs_path: worktree.abs_path.clone(),
1846 })
1847 .collect::<Vec<_>>();
1848
1849 let add_project_collaborator = proto::AddProjectCollaborator {
1850 project_id: project_id.to_proto(),
1851 collaborator: Some(proto::Collaborator {
1852 peer_id: Some(session.connection_id.into()),
1853 replica_id: replica_id.0 as u32,
1854 user_id: guest_user_id.to_proto(),
1855 is_host: false,
1856 }),
1857 };
1858
1859 for collaborator in &collaborators {
1860 session
1861 .peer
1862 .send(
1863 collaborator.peer_id.unwrap().into(),
1864 add_project_collaborator.clone(),
1865 )
1866 .trace_err();
1867 }
1868
1869 // First, we send the metadata associated with each worktree.
1870 response.send(proto::JoinProjectResponse {
1871 project_id: project.id.0 as u64,
1872 worktrees: worktrees.clone(),
1873 replica_id: replica_id.0 as u32,
1874 collaborators: collaborators.clone(),
1875 language_servers: project.language_servers.clone(),
1876 role: project.role.into(),
1877 })?;
1878
1879 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1880 // Stream this worktree's entries.
1881 let message = proto::UpdateWorktree {
1882 project_id: project_id.to_proto(),
1883 worktree_id,
1884 abs_path: worktree.abs_path.clone(),
1885 root_name: worktree.root_name,
1886 updated_entries: worktree.entries,
1887 removed_entries: Default::default(),
1888 scan_id: worktree.scan_id,
1889 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1890 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1891 removed_repositories: Default::default(),
1892 };
1893 for update in proto::split_worktree_update(message) {
1894 session.peer.send(session.connection_id, update.clone())?;
1895 }
1896
1897 // Stream this worktree's diagnostics.
1898 for summary in worktree.diagnostic_summaries {
1899 session.peer.send(
1900 session.connection_id,
1901 proto::UpdateDiagnosticSummary {
1902 project_id: project_id.to_proto(),
1903 worktree_id: worktree.id,
1904 summary: Some(summary),
1905 },
1906 )?;
1907 }
1908
1909 for settings_file in worktree.settings_files {
1910 session.peer.send(
1911 session.connection_id,
1912 proto::UpdateWorktreeSettings {
1913 project_id: project_id.to_proto(),
1914 worktree_id: worktree.id,
1915 path: settings_file.path,
1916 content: Some(settings_file.content),
1917 kind: Some(settings_file.kind.to_proto() as i32),
1918 },
1919 )?;
1920 }
1921 }
1922
1923 for repository in mem::take(&mut project.repositories) {
1924 for update in split_repository_update(repository) {
1925 session.peer.send(session.connection_id, update)?;
1926 }
1927 }
1928
1929 for language_server in &project.language_servers {
1930 session.peer.send(
1931 session.connection_id,
1932 proto::UpdateLanguageServer {
1933 project_id: project_id.to_proto(),
1934 language_server_id: language_server.id,
1935 variant: Some(
1936 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1937 proto::LspDiskBasedDiagnosticsUpdated {},
1938 ),
1939 ),
1940 },
1941 )?;
1942 }
1943
1944 Ok(())
1945}
1946
1947/// Leave someone elses shared project.
1948async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1949 let sender_id = session.connection_id;
1950 let project_id = ProjectId::from_proto(request.project_id);
1951 let db = session.db().await;
1952
1953 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1954 tracing::info!(
1955 %project_id,
1956 "leave project"
1957 );
1958
1959 project_left(project, &session);
1960 if let Some(room) = room {
1961 room_updated(room, &session.peer);
1962 }
1963
1964 Ok(())
1965}
1966
1967/// Updates other participants with changes to the project
1968async fn update_project(
1969 request: proto::UpdateProject,
1970 response: Response<proto::UpdateProject>,
1971 session: Session,
1972) -> Result<()> {
1973 let project_id = ProjectId::from_proto(request.project_id);
1974 let (room, guest_connection_ids) = &*session
1975 .db()
1976 .await
1977 .update_project(project_id, session.connection_id, &request.worktrees)
1978 .await?;
1979 broadcast(
1980 Some(session.connection_id),
1981 guest_connection_ids.iter().copied(),
1982 |connection_id| {
1983 session
1984 .peer
1985 .forward_send(session.connection_id, connection_id, request.clone())
1986 },
1987 );
1988 if let Some(room) = room {
1989 room_updated(room, &session.peer);
1990 }
1991 response.send(proto::Ack {})?;
1992
1993 Ok(())
1994}
1995
1996/// Updates other participants with changes to the worktree
1997async fn update_worktree(
1998 request: proto::UpdateWorktree,
1999 response: Response<proto::UpdateWorktree>,
2000 session: Session,
2001) -> Result<()> {
2002 let guest_connection_ids = session
2003 .db()
2004 .await
2005 .update_worktree(&request, session.connection_id)
2006 .await?;
2007
2008 broadcast(
2009 Some(session.connection_id),
2010 guest_connection_ids.iter().copied(),
2011 |connection_id| {
2012 session
2013 .peer
2014 .forward_send(session.connection_id, connection_id, request.clone())
2015 },
2016 );
2017 response.send(proto::Ack {})?;
2018 Ok(())
2019}
2020
2021async fn update_repository(
2022 request: proto::UpdateRepository,
2023 response: Response<proto::UpdateRepository>,
2024 session: Session,
2025) -> Result<()> {
2026 let guest_connection_ids = session
2027 .db()
2028 .await
2029 .update_repository(&request, session.connection_id)
2030 .await?;
2031
2032 broadcast(
2033 Some(session.connection_id),
2034 guest_connection_ids.iter().copied(),
2035 |connection_id| {
2036 session
2037 .peer
2038 .forward_send(session.connection_id, connection_id, request.clone())
2039 },
2040 );
2041 response.send(proto::Ack {})?;
2042 Ok(())
2043}
2044
2045async fn remove_repository(
2046 request: proto::RemoveRepository,
2047 response: Response<proto::RemoveRepository>,
2048 session: Session,
2049) -> Result<()> {
2050 let guest_connection_ids = session
2051 .db()
2052 .await
2053 .remove_repository(&request, session.connection_id)
2054 .await?;
2055
2056 broadcast(
2057 Some(session.connection_id),
2058 guest_connection_ids.iter().copied(),
2059 |connection_id| {
2060 session
2061 .peer
2062 .forward_send(session.connection_id, connection_id, request.clone())
2063 },
2064 );
2065 response.send(proto::Ack {})?;
2066 Ok(())
2067}
2068
2069/// Updates other participants with changes to the diagnostics
2070async fn update_diagnostic_summary(
2071 message: proto::UpdateDiagnosticSummary,
2072 session: Session,
2073) -> Result<()> {
2074 let guest_connection_ids = session
2075 .db()
2076 .await
2077 .update_diagnostic_summary(&message, session.connection_id)
2078 .await?;
2079
2080 broadcast(
2081 Some(session.connection_id),
2082 guest_connection_ids.iter().copied(),
2083 |connection_id| {
2084 session
2085 .peer
2086 .forward_send(session.connection_id, connection_id, message.clone())
2087 },
2088 );
2089
2090 Ok(())
2091}
2092
2093/// Updates other participants with changes to the worktree settings
2094async fn update_worktree_settings(
2095 message: proto::UpdateWorktreeSettings,
2096 session: Session,
2097) -> Result<()> {
2098 let guest_connection_ids = session
2099 .db()
2100 .await
2101 .update_worktree_settings(&message, session.connection_id)
2102 .await?;
2103
2104 broadcast(
2105 Some(session.connection_id),
2106 guest_connection_ids.iter().copied(),
2107 |connection_id| {
2108 session
2109 .peer
2110 .forward_send(session.connection_id, connection_id, message.clone())
2111 },
2112 );
2113
2114 Ok(())
2115}
2116
2117/// Notify other participants that a language server has started.
2118async fn start_language_server(
2119 request: proto::StartLanguageServer,
2120 session: Session,
2121) -> Result<()> {
2122 let guest_connection_ids = session
2123 .db()
2124 .await
2125 .start_language_server(&request, session.connection_id)
2126 .await?;
2127
2128 broadcast(
2129 Some(session.connection_id),
2130 guest_connection_ids.iter().copied(),
2131 |connection_id| {
2132 session
2133 .peer
2134 .forward_send(session.connection_id, connection_id, request.clone())
2135 },
2136 );
2137 Ok(())
2138}
2139
2140/// Notify other participants that a language server has changed.
2141async fn update_language_server(
2142 request: proto::UpdateLanguageServer,
2143 session: Session,
2144) -> Result<()> {
2145 let project_id = ProjectId::from_proto(request.project_id);
2146 let project_connection_ids = session
2147 .db()
2148 .await
2149 .project_connection_ids(project_id, session.connection_id, true)
2150 .await?;
2151 broadcast(
2152 Some(session.connection_id),
2153 project_connection_ids.iter().copied(),
2154 |connection_id| {
2155 session
2156 .peer
2157 .forward_send(session.connection_id, connection_id, request.clone())
2158 },
2159 );
2160 Ok(())
2161}
2162
2163/// forward a project request to the host. These requests should be read only
2164/// as guests are allowed to send them.
2165async fn forward_read_only_project_request<T>(
2166 request: T,
2167 response: Response<T>,
2168 session: Session,
2169) -> Result<()>
2170where
2171 T: EntityMessage + RequestMessage,
2172{
2173 let project_id = ProjectId::from_proto(request.remote_entity_id());
2174 let host_connection_id = session
2175 .db()
2176 .await
2177 .host_for_read_only_project_request(project_id, session.connection_id)
2178 .await?;
2179 let payload = session
2180 .peer
2181 .forward_request(session.connection_id, host_connection_id, request)
2182 .await?;
2183 response.send(payload)?;
2184 Ok(())
2185}
2186
2187async fn forward_find_search_candidates_request(
2188 request: proto::FindSearchCandidates,
2189 response: Response<proto::FindSearchCandidates>,
2190 session: Session,
2191) -> Result<()> {
2192 let project_id = ProjectId::from_proto(request.remote_entity_id());
2193 let host_connection_id = session
2194 .db()
2195 .await
2196 .host_for_read_only_project_request(project_id, session.connection_id)
2197 .await?;
2198 let payload = session
2199 .peer
2200 .forward_request(session.connection_id, host_connection_id, request)
2201 .await?;
2202 response.send(payload)?;
2203 Ok(())
2204}
2205
2206/// forward a project request to the host. These requests are disallowed
2207/// for guests.
2208async fn forward_mutating_project_request<T>(
2209 request: T,
2210 response: Response<T>,
2211 session: Session,
2212) -> Result<()>
2213where
2214 T: EntityMessage + RequestMessage,
2215{
2216 let project_id = ProjectId::from_proto(request.remote_entity_id());
2217
2218 let host_connection_id = session
2219 .db()
2220 .await
2221 .host_for_mutating_project_request(project_id, session.connection_id)
2222 .await?;
2223 let payload = session
2224 .peer
2225 .forward_request(session.connection_id, host_connection_id, request)
2226 .await?;
2227 response.send(payload)?;
2228 Ok(())
2229}
2230
2231/// Notify other participants that a new buffer has been created
2232async fn create_buffer_for_peer(
2233 request: proto::CreateBufferForPeer,
2234 session: Session,
2235) -> Result<()> {
2236 session
2237 .db()
2238 .await
2239 .check_user_is_project_host(
2240 ProjectId::from_proto(request.project_id),
2241 session.connection_id,
2242 )
2243 .await?;
2244 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2245 session
2246 .peer
2247 .forward_send(session.connection_id, peer_id.into(), request)?;
2248 Ok(())
2249}
2250
2251/// Notify other participants that a buffer has been updated. This is
2252/// allowed for guests as long as the update is limited to selections.
2253async fn update_buffer(
2254 request: proto::UpdateBuffer,
2255 response: Response<proto::UpdateBuffer>,
2256 session: Session,
2257) -> Result<()> {
2258 let project_id = ProjectId::from_proto(request.project_id);
2259 let mut capability = Capability::ReadOnly;
2260
2261 for op in request.operations.iter() {
2262 match op.variant {
2263 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2264 Some(_) => capability = Capability::ReadWrite,
2265 }
2266 }
2267
2268 let host = {
2269 let guard = session
2270 .db()
2271 .await
2272 .connections_for_buffer_update(project_id, session.connection_id, capability)
2273 .await?;
2274
2275 let (host, guests) = &*guard;
2276
2277 broadcast(
2278 Some(session.connection_id),
2279 guests.clone(),
2280 |connection_id| {
2281 session
2282 .peer
2283 .forward_send(session.connection_id, connection_id, request.clone())
2284 },
2285 );
2286
2287 *host
2288 };
2289
2290 if host != session.connection_id {
2291 session
2292 .peer
2293 .forward_request(session.connection_id, host, request.clone())
2294 .await?;
2295 }
2296
2297 response.send(proto::Ack {})?;
2298 Ok(())
2299}
2300
2301async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2302 let project_id = ProjectId::from_proto(message.project_id);
2303
2304 let operation = message.operation.as_ref().context("invalid operation")?;
2305 let capability = match operation.variant.as_ref() {
2306 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2307 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2308 match buffer_op.variant {
2309 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2310 Capability::ReadOnly
2311 }
2312 _ => Capability::ReadWrite,
2313 }
2314 } else {
2315 Capability::ReadWrite
2316 }
2317 }
2318 Some(_) => Capability::ReadWrite,
2319 None => Capability::ReadOnly,
2320 };
2321
2322 let guard = session
2323 .db()
2324 .await
2325 .connections_for_buffer_update(project_id, session.connection_id, capability)
2326 .await?;
2327
2328 let (host, guests) = &*guard;
2329
2330 broadcast(
2331 Some(session.connection_id),
2332 guests.iter().chain([host]).copied(),
2333 |connection_id| {
2334 session
2335 .peer
2336 .forward_send(session.connection_id, connection_id, message.clone())
2337 },
2338 );
2339
2340 Ok(())
2341}
2342
2343/// Notify other participants that a project has been updated.
2344async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2345 request: T,
2346 session: Session,
2347) -> Result<()> {
2348 let project_id = ProjectId::from_proto(request.remote_entity_id());
2349 let project_connection_ids = session
2350 .db()
2351 .await
2352 .project_connection_ids(project_id, session.connection_id, false)
2353 .await?;
2354
2355 broadcast(
2356 Some(session.connection_id),
2357 project_connection_ids.iter().copied(),
2358 |connection_id| {
2359 session
2360 .peer
2361 .forward_send(session.connection_id, connection_id, request.clone())
2362 },
2363 );
2364 Ok(())
2365}
2366
2367/// Start following another user in a call.
2368async fn follow(
2369 request: proto::Follow,
2370 response: Response<proto::Follow>,
2371 session: Session,
2372) -> Result<()> {
2373 let room_id = RoomId::from_proto(request.room_id);
2374 let project_id = request.project_id.map(ProjectId::from_proto);
2375 let leader_id = request
2376 .leader_id
2377 .ok_or_else(|| anyhow!("invalid leader id"))?
2378 .into();
2379 let follower_id = session.connection_id;
2380
2381 session
2382 .db()
2383 .await
2384 .check_room_participants(room_id, leader_id, session.connection_id)
2385 .await?;
2386
2387 let response_payload = session
2388 .peer
2389 .forward_request(session.connection_id, leader_id, request)
2390 .await?;
2391 response.send(response_payload)?;
2392
2393 if let Some(project_id) = project_id {
2394 let room = session
2395 .db()
2396 .await
2397 .follow(room_id, project_id, leader_id, follower_id)
2398 .await?;
2399 room_updated(&room, &session.peer);
2400 }
2401
2402 Ok(())
2403}
2404
2405/// Stop following another user in a call.
2406async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2407 let room_id = RoomId::from_proto(request.room_id);
2408 let project_id = request.project_id.map(ProjectId::from_proto);
2409 let leader_id = request
2410 .leader_id
2411 .ok_or_else(|| anyhow!("invalid leader id"))?
2412 .into();
2413 let follower_id = session.connection_id;
2414
2415 session
2416 .db()
2417 .await
2418 .check_room_participants(room_id, leader_id, session.connection_id)
2419 .await?;
2420
2421 session
2422 .peer
2423 .forward_send(session.connection_id, leader_id, request)?;
2424
2425 if let Some(project_id) = project_id {
2426 let room = session
2427 .db()
2428 .await
2429 .unfollow(room_id, project_id, leader_id, follower_id)
2430 .await?;
2431 room_updated(&room, &session.peer);
2432 }
2433
2434 Ok(())
2435}
2436
2437/// Notify everyone following you of your current location.
2438async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2439 let room_id = RoomId::from_proto(request.room_id);
2440 let database = session.db.lock().await;
2441
2442 let connection_ids = if let Some(project_id) = request.project_id {
2443 let project_id = ProjectId::from_proto(project_id);
2444 database
2445 .project_connection_ids(project_id, session.connection_id, true)
2446 .await?
2447 } else {
2448 database
2449 .room_connection_ids(room_id, session.connection_id)
2450 .await?
2451 };
2452
2453 // For now, don't send view update messages back to that view's current leader.
2454 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2455 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2456 _ => None,
2457 });
2458
2459 for connection_id in connection_ids.iter().cloned() {
2460 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2461 session
2462 .peer
2463 .forward_send(session.connection_id, connection_id, request.clone())?;
2464 }
2465 }
2466 Ok(())
2467}
2468
2469/// Get public data about users.
2470async fn get_users(
2471 request: proto::GetUsers,
2472 response: Response<proto::GetUsers>,
2473 session: Session,
2474) -> Result<()> {
2475 let user_ids = request
2476 .user_ids
2477 .into_iter()
2478 .map(UserId::from_proto)
2479 .collect();
2480 let users = session
2481 .db()
2482 .await
2483 .get_users_by_ids(user_ids)
2484 .await?
2485 .into_iter()
2486 .map(|user| proto::User {
2487 id: user.id.to_proto(),
2488 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2489 github_login: user.github_login,
2490 email: user.email_address,
2491 name: user.name,
2492 })
2493 .collect();
2494 response.send(proto::UsersResponse { users })?;
2495 Ok(())
2496}
2497
2498/// Search for users (to invite) buy Github login
2499async fn fuzzy_search_users(
2500 request: proto::FuzzySearchUsers,
2501 response: Response<proto::FuzzySearchUsers>,
2502 session: Session,
2503) -> Result<()> {
2504 let query = request.query;
2505 let users = match query.len() {
2506 0 => vec![],
2507 1 | 2 => session
2508 .db()
2509 .await
2510 .get_user_by_github_login(&query)
2511 .await?
2512 .into_iter()
2513 .collect(),
2514 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2515 };
2516 let users = users
2517 .into_iter()
2518 .filter(|user| user.id != session.user_id())
2519 .map(|user| proto::User {
2520 id: user.id.to_proto(),
2521 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2522 github_login: user.github_login,
2523 name: user.name,
2524 email: user.email_address,
2525 })
2526 .collect();
2527 response.send(proto::UsersResponse { users })?;
2528 Ok(())
2529}
2530
2531/// Send a contact request to another user.
2532async fn request_contact(
2533 request: proto::RequestContact,
2534 response: Response<proto::RequestContact>,
2535 session: Session,
2536) -> Result<()> {
2537 let requester_id = session.user_id();
2538 let responder_id = UserId::from_proto(request.responder_id);
2539 if requester_id == responder_id {
2540 return Err(anyhow!("cannot add yourself as a contact"))?;
2541 }
2542
2543 let notifications = session
2544 .db()
2545 .await
2546 .send_contact_request(requester_id, responder_id)
2547 .await?;
2548
2549 // Update outgoing contact requests of requester
2550 let mut update = proto::UpdateContacts::default();
2551 update.outgoing_requests.push(responder_id.to_proto());
2552 for connection_id in session
2553 .connection_pool()
2554 .await
2555 .user_connection_ids(requester_id)
2556 {
2557 session.peer.send(connection_id, update.clone())?;
2558 }
2559
2560 // Update incoming contact requests of responder
2561 let mut update = proto::UpdateContacts::default();
2562 update
2563 .incoming_requests
2564 .push(proto::IncomingContactRequest {
2565 requester_id: requester_id.to_proto(),
2566 });
2567 let connection_pool = session.connection_pool().await;
2568 for connection_id in connection_pool.user_connection_ids(responder_id) {
2569 session.peer.send(connection_id, update.clone())?;
2570 }
2571
2572 send_notifications(&connection_pool, &session.peer, notifications);
2573
2574 response.send(proto::Ack {})?;
2575 Ok(())
2576}
2577
/// Accept or decline a contact request
///
/// Handles the current user's response to a contact request from
/// `request.requester_id`:
/// - `Dismiss` only dismisses the contact notification; no clients are updated.
/// - Otherwise the request is accepted or declined in the database. If
///   accepted, each user's clients receive the other user as a contact. In
///   both cases the pending request is removed from each side, and any
///   notifications the database returns are delivered.
///
/// The request is acknowledged in every case.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard stays held for the rest of this function.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any value other than `Accept` (after `Dismiss` is excluded) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are included in the contact entries sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2635
2636/// Remove a contact.
2637async fn remove_contact(
2638 request: proto::RemoveContact,
2639 response: Response<proto::RemoveContact>,
2640 session: Session,
2641) -> Result<()> {
2642 let requester_id = session.user_id();
2643 let responder_id = UserId::from_proto(request.user_id);
2644 let db = session.db().await;
2645 let (contact_accepted, deleted_notification_id) =
2646 db.remove_contact(requester_id, responder_id).await?;
2647
2648 let pool = session.connection_pool().await;
2649 // Update outgoing contact requests of requester
2650 let mut update = proto::UpdateContacts::default();
2651 if contact_accepted {
2652 update.remove_contacts.push(responder_id.to_proto());
2653 } else {
2654 update
2655 .remove_outgoing_requests
2656 .push(responder_id.to_proto());
2657 }
2658 for connection_id in pool.user_connection_ids(requester_id) {
2659 session.peer.send(connection_id, update.clone())?;
2660 }
2661
2662 // Update incoming contact requests of responder
2663 let mut update = proto::UpdateContacts::default();
2664 if contact_accepted {
2665 update.remove_contacts.push(requester_id.to_proto());
2666 } else {
2667 update
2668 .remove_incoming_requests
2669 .push(requester_id.to_proto());
2670 }
2671 for connection_id in pool.user_connection_ids(responder_id) {
2672 session.peer.send(connection_id, update.clone())?;
2673 if let Some(notification_id) = deleted_notification_id {
2674 session.peer.send(
2675 connection_id,
2676 proto::DeleteNotification {
2677 notification_id: notification_id.to_proto(),
2678 },
2679 )?;
2680 }
2681 }
2682
2683 response.send(proto::Ack {})?;
2684 Ok(())
2685}
2686
2687fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2688 version.0.minor() < 139
2689}
2690
/// Sends the current connection an `UpdateUserPlan` message for `user_id`.
///
/// The message carries the user's plan, trial start time, usage-based-billing
/// setting, current subscription period, and usage within that period.
/// Subscription period and usage are only looked up when
/// `session.app_state.llm_db` is configured; otherwise both are omitted.
/// Database errors are propagated. A failure to send the message is passed to
/// `trace_err` instead of being returned.
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let feature_flags = db.get_user_flags(user_id).await?;
    let plan = session.current_plan(&db).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
    let billing_preferences = db.get_billing_preferences(user_id).await?;

    // Period and usage both require the LLM database.
    let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
        let subscription = db.get_active_billing_subscription(user_id).await?;

        // NOTE(review): staff status is passed in; confirm in `current_period`
        // how it affects the returned period.
        let subscription_period = crate::db::billing_subscription::Model::current_period(
            subscription,
            session.is_staff(),
        );

        // Usage is only queried when there is a period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    session
        .peer
        .send(
            session.connection_id,
            proto::UpdateUserPlan {
                plan: plan.into(),
                // Seconds since the Unix epoch, treating the stored naive timestamp as UTC.
                trial_started_at: billing_customer
                    .and_then(|billing_customer| billing_customer.trial_started_at)
                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
                // Staff are always reported as having usage-based billing enabled.
                is_usage_based_billing_enabled: if session.is_staff() {
                    Some(true)
                } else {
                    billing_preferences
                        .map(|preferences| preferences.model_request_overages_enabled)
                },
                subscription_period: subscription_period.map(|(started_at, ended_at)| {
                    proto::SubscriptionPeriod {
                        started_at: started_at.timestamp() as u64,
                        ended_at: ended_at.timestamp() as u64,
                    }
                }),
                usage: usage.map(|usage| {
                    // Map the protocol plan onto `zed_llm_client`'s plan so its
                    // per-plan limits can be used.
                    let plan = match plan {
                        proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
                    };

                    // Trial users with the extended-trial feature flag get a
                    // limited allowance of 1,000 model requests instead of the
                    // plan's default limit.
                    let model_requests_limit = match plan.model_requests_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
                                && feature_flags
                                    .iter()
                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                            {
                                1_000
                            } else {
                                limit
                            };

                            zed_llm_client::UsageLimit::Limited(limit)
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            zed_llm_client::UsageLimit::Unlimited
                        }
                    };

                    // NOTE(review): the `as u32` casts below truncate silently if
                    // a usage count or limit ever exceeds u32::MAX.
                    proto::SubscriptionUsage {
                        model_requests_usage_amount: usage.model_requests as u32,
                        model_requests_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match model_requests_limit {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                        edit_predictions_usage_amount: usage.edit_predictions as u32,
                        edit_predictions_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match plan.edit_predictions_limit() {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                    }
                }),
            },
        )
        .trace_err();

    Ok(())
}
2810
2811async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2812 subscribe_user_to_channels(session.user_id(), &session).await?;
2813 Ok(())
2814}
2815
2816async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2817 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2818 let mut pool = session.connection_pool().await;
2819 for membership in &channels_for_user.channel_memberships {
2820 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2821 }
2822 session.peer.send(
2823 session.connection_id,
2824 build_update_user_channels(&channels_for_user),
2825 )?;
2826 session.peer.send(
2827 session.connection_id,
2828 build_channels_update(channels_for_user),
2829 )?;
2830 Ok(())
2831}
2832
2833/// Creates a new channel.
2834async fn create_channel(
2835 request: proto::CreateChannel,
2836 response: Response<proto::CreateChannel>,
2837 session: Session,
2838) -> Result<()> {
2839 let db = session.db().await;
2840
2841 let parent_id = request.parent_id.map(ChannelId::from_proto);
2842 let (channel, membership) = db
2843 .create_channel(&request.name, parent_id, session.user_id())
2844 .await?;
2845
2846 let root_id = channel.root_id();
2847 let channel = Channel::from_model(channel);
2848
2849 response.send(proto::CreateChannelResponse {
2850 channel: Some(channel.to_proto()),
2851 parent_id: request.parent_id,
2852 })?;
2853
2854 let mut connection_pool = session.connection_pool().await;
2855 if let Some(membership) = membership {
2856 connection_pool.subscribe_to_channel(
2857 membership.user_id,
2858 membership.channel_id,
2859 membership.role,
2860 );
2861 let update = proto::UpdateUserChannels {
2862 channel_memberships: vec![proto::ChannelMembership {
2863 channel_id: membership.channel_id.to_proto(),
2864 role: membership.role.into(),
2865 }],
2866 ..Default::default()
2867 };
2868 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2869 session.peer.send(connection_id, update.clone())?;
2870 }
2871 }
2872
2873 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2874 if !role.can_see_channel(channel.visibility) {
2875 continue;
2876 }
2877
2878 let update = proto::UpdateChannels {
2879 channels: vec![channel.to_proto()],
2880 ..Default::default()
2881 };
2882 session.peer.send(connection_id, update.clone())?;
2883 }
2884
2885 Ok(())
2886}
2887
2888/// Delete a channel
2889async fn delete_channel(
2890 request: proto::DeleteChannel,
2891 response: Response<proto::DeleteChannel>,
2892 session: Session,
2893) -> Result<()> {
2894 let db = session.db().await;
2895
2896 let channel_id = request.channel_id;
2897 let (root_channel, removed_channels) = db
2898 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2899 .await?;
2900 response.send(proto::Ack {})?;
2901
2902 // Notify members of removed channels
2903 let mut update = proto::UpdateChannels::default();
2904 update
2905 .delete_channels
2906 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2907
2908 let connection_pool = session.connection_pool().await;
2909 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2910 session.peer.send(connection_id, update.clone())?;
2911 }
2912
2913 Ok(())
2914}
2915
2916/// Invite someone to join a channel.
2917async fn invite_channel_member(
2918 request: proto::InviteChannelMember,
2919 response: Response<proto::InviteChannelMember>,
2920 session: Session,
2921) -> Result<()> {
2922 let db = session.db().await;
2923 let channel_id = ChannelId::from_proto(request.channel_id);
2924 let invitee_id = UserId::from_proto(request.user_id);
2925 let InviteMemberResult {
2926 channel,
2927 notifications,
2928 } = db
2929 .invite_channel_member(
2930 channel_id,
2931 invitee_id,
2932 session.user_id(),
2933 request.role().into(),
2934 )
2935 .await?;
2936
2937 let update = proto::UpdateChannels {
2938 channel_invitations: vec![channel.to_proto()],
2939 ..Default::default()
2940 };
2941
2942 let connection_pool = session.connection_pool().await;
2943 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2944 session.peer.send(connection_id, update.clone())?;
2945 }
2946
2947 send_notifications(&connection_pool, &session.peer, notifications);
2948
2949 response.send(proto::Ack {})?;
2950 Ok(())
2951}
2952
2953/// remove someone from a channel
2954async fn remove_channel_member(
2955 request: proto::RemoveChannelMember,
2956 response: Response<proto::RemoveChannelMember>,
2957 session: Session,
2958) -> Result<()> {
2959 let db = session.db().await;
2960 let channel_id = ChannelId::from_proto(request.channel_id);
2961 let member_id = UserId::from_proto(request.user_id);
2962
2963 let RemoveChannelMemberResult {
2964 membership_update,
2965 notification_id,
2966 } = db
2967 .remove_channel_member(channel_id, member_id, session.user_id())
2968 .await?;
2969
2970 let mut connection_pool = session.connection_pool().await;
2971 notify_membership_updated(
2972 &mut connection_pool,
2973 membership_update,
2974 member_id,
2975 &session.peer,
2976 );
2977 for connection_id in connection_pool.user_connection_ids(member_id) {
2978 if let Some(notification_id) = notification_id {
2979 session
2980 .peer
2981 .send(
2982 connection_id,
2983 proto::DeleteNotification {
2984 notification_id: notification_id.to_proto(),
2985 },
2986 )
2987 .trace_err();
2988 }
2989 }
2990
2991 response.send(proto::Ack {})?;
2992 Ok(())
2993}
2994
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated against the new visibility. Users whose role can see the
/// channel are subscribed to it and sent the channel. Everyone else is
/// unsubscribed and told to delete it. NOTE(review): the invariant above is
/// not checked in this function; presumably the database-layer
/// `set_channel_visibility` enforces it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool` (subscribe /
    // unsubscribe), which cannot overlap with the iterator's borrow of it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3041
3042/// Alter the role for a user in the channel.
3043async fn set_channel_member_role(
3044 request: proto::SetChannelMemberRole,
3045 response: Response<proto::SetChannelMemberRole>,
3046 session: Session,
3047) -> Result<()> {
3048 let db = session.db().await;
3049 let channel_id = ChannelId::from_proto(request.channel_id);
3050 let member_id = UserId::from_proto(request.user_id);
3051 let result = db
3052 .set_channel_member_role(
3053 channel_id,
3054 session.user_id(),
3055 member_id,
3056 request.role().into(),
3057 )
3058 .await?;
3059
3060 match result {
3061 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3062 let mut connection_pool = session.connection_pool().await;
3063 notify_membership_updated(
3064 &mut connection_pool,
3065 membership_update,
3066 member_id,
3067 &session.peer,
3068 )
3069 }
3070 db::SetMemberRoleResult::InviteUpdated(channel) => {
3071 let update = proto::UpdateChannels {
3072 channel_invitations: vec![channel.to_proto()],
3073 ..Default::default()
3074 };
3075
3076 for connection_id in session
3077 .connection_pool()
3078 .await
3079 .user_connection_ids(member_id)
3080 {
3081 session.peer.send(connection_id, update.clone())?;
3082 }
3083 }
3084 }
3085
3086 response.send(proto::Ack {})?;
3087 Ok(())
3088}
3089
3090/// Change the name of a channel
3091async fn rename_channel(
3092 request: proto::RenameChannel,
3093 response: Response<proto::RenameChannel>,
3094 session: Session,
3095) -> Result<()> {
3096 let db = session.db().await;
3097 let channel_id = ChannelId::from_proto(request.channel_id);
3098 let channel_model = db
3099 .rename_channel(channel_id, session.user_id(), &request.name)
3100 .await?;
3101 let root_id = channel_model.root_id();
3102 let channel = Channel::from_model(channel_model);
3103
3104 response.send(proto::RenameChannelResponse {
3105 channel: Some(channel.to_proto()),
3106 })?;
3107
3108 let connection_pool = session.connection_pool().await;
3109 let update = proto::UpdateChannels {
3110 channels: vec![channel.to_proto()],
3111 ..Default::default()
3112 };
3113 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3114 if role.can_see_channel(channel.visibility) {
3115 session.peer.send(connection_id, update.clone())?;
3116 }
3117 }
3118
3119 Ok(())
3120}
3121
3122/// Move a channel to a new parent.
3123async fn move_channel(
3124 request: proto::MoveChannel,
3125 response: Response<proto::MoveChannel>,
3126 session: Session,
3127) -> Result<()> {
3128 let channel_id = ChannelId::from_proto(request.channel_id);
3129 let to = ChannelId::from_proto(request.to);
3130
3131 let (root_id, channels) = session
3132 .db()
3133 .await
3134 .move_channel(channel_id, to, session.user_id())
3135 .await?;
3136
3137 let connection_pool = session.connection_pool().await;
3138 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3139 let channels = channels
3140 .iter()
3141 .filter_map(|channel| {
3142 if role.can_see_channel(channel.visibility) {
3143 Some(channel.to_proto())
3144 } else {
3145 None
3146 }
3147 })
3148 .collect::<Vec<_>>();
3149 if channels.is_empty() {
3150 continue;
3151 }
3152
3153 let update = proto::UpdateChannels {
3154 channels,
3155 ..Default::default()
3156 };
3157
3158 session.peer.send(connection_id, update.clone())?;
3159 }
3160
3161 response.send(Ack {})?;
3162 Ok(())
3163}
3164
3165/// Get the list of channel members
3166async fn get_channel_members(
3167 request: proto::GetChannelMembers,
3168 response: Response<proto::GetChannelMembers>,
3169 session: Session,
3170) -> Result<()> {
3171 let db = session.db().await;
3172 let channel_id = ChannelId::from_proto(request.channel_id);
3173 let limit = if request.limit == 0 {
3174 u16::MAX as u64
3175 } else {
3176 request.limit
3177 };
3178 let (members, users) = db
3179 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3180 .await?;
3181 response.send(proto::GetChannelMembersResponse { members, users })?;
3182 Ok(())
3183}
3184
3185/// Accept or decline a channel invitation.
3186async fn respond_to_channel_invite(
3187 request: proto::RespondToChannelInvite,
3188 response: Response<proto::RespondToChannelInvite>,
3189 session: Session,
3190) -> Result<()> {
3191 let db = session.db().await;
3192 let channel_id = ChannelId::from_proto(request.channel_id);
3193 let RespondToChannelInvite {
3194 membership_update,
3195 notifications,
3196 } = db
3197 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3198 .await?;
3199
3200 let mut connection_pool = session.connection_pool().await;
3201 if let Some(membership_update) = membership_update {
3202 notify_membership_updated(
3203 &mut connection_pool,
3204 membership_update,
3205 session.user_id(),
3206 &session.peer,
3207 );
3208 } else {
3209 let update = proto::UpdateChannels {
3210 remove_channel_invitations: vec![channel_id.to_proto()],
3211 ..Default::default()
3212 };
3213
3214 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3215 session.peer.send(connection_id, update.clone())?;
3216 }
3217 };
3218
3219 send_notifications(&connection_pool, &session.peer, notifications);
3220
3221 response.send(proto::Ack {})?;
3222
3223 Ok(())
3224}
3225
3226/// Join the channels' room
3227async fn join_channel(
3228 request: proto::JoinChannel,
3229 response: Response<proto::JoinChannel>,
3230 session: Session,
3231) -> Result<()> {
3232 let channel_id = ChannelId::from_proto(request.channel_id);
3233 join_channel_internal(channel_id, Box::new(response), session).await
3234}
3235
/// Abstraction over the responders that reply with a `proto::JoinRoomResponse`
/// (implemented below for the `JoinChannel` and `JoinRoom` responses), so that
/// `join_channel_internal` can be driven by either one.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates via an explicit type path, which presumably resolves to the
// inherent `Response::send` rather than recursing into this trait method —
// keep the qualified form if this is ever touched.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation for the `JoinRoom` responder.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3249
/// Join the room of `channel_id` for the current session and notify everyone
/// affected.
///
/// Steps:
/// 1. Evicts a stale room connection for this user, if the database reports one.
/// 2. Joins the channel's room in the database.
/// 3. Replies with the room, the channel id, and (when a LiveKit client is
///    configured) LiveKit connection info.
/// 4. Applies any membership update to the connection pool.
/// 5. Broadcasts the room state, the channel's participant list, and the
///    user's contact status.
///
/// # Errors
/// Fails if a database call or the response send fails. It also fails with
/// "channel not returned" when the joined room has no channel; that check runs
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` guard acquired at the top of this block lives until the block
    // ends, so it is still held while the response is sent, the connection
    // pool is locked and the room update is broadcast.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires `session.db()` itself, so the
            // guard is released first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the
        // token fails. The failure is passed to `trace_err` and then
        // short-circuits this closure via `?` on the `Option`.
        // Guests get a guest token and `can_publish: false`; every other role
        // gets a room token and `can_publish: true`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Runs after the guards above are released; a fresh connection-pool guard
    // is taken just for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3345
3346/// Start editing the channel notes
3347async fn join_channel_buffer(
3348 request: proto::JoinChannelBuffer,
3349 response: Response<proto::JoinChannelBuffer>,
3350 session: Session,
3351) -> Result<()> {
3352 let db = session.db().await;
3353 let channel_id = ChannelId::from_proto(request.channel_id);
3354
3355 let open_response = db
3356 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3357 .await?;
3358
3359 let collaborators = open_response.collaborators.clone();
3360 response.send(open_response)?;
3361
3362 let update = UpdateChannelBufferCollaborators {
3363 channel_id: channel_id.to_proto(),
3364 collaborators: collaborators.clone(),
3365 };
3366 channel_buffer_updated(
3367 session.connection_id,
3368 collaborators
3369 .iter()
3370 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3371 &update,
3372 &session.peer,
3373 );
3374
3375 Ok(())
3376}
3377
3378/// Edit the channel notes
3379async fn update_channel_buffer(
3380 request: proto::UpdateChannelBuffer,
3381 session: Session,
3382) -> Result<()> {
3383 let db = session.db().await;
3384 let channel_id = ChannelId::from_proto(request.channel_id);
3385
3386 let (collaborators, epoch, version) = db
3387 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3388 .await?;
3389
3390 channel_buffer_updated(
3391 session.connection_id,
3392 collaborators.clone(),
3393 &proto::UpdateChannelBuffer {
3394 channel_id: channel_id.to_proto(),
3395 operations: request.operations,
3396 },
3397 &session.peer,
3398 );
3399
3400 let pool = &*session.connection_pool().await;
3401
3402 let non_collaborators =
3403 pool.channel_connection_ids(channel_id)
3404 .filter_map(|(connection_id, _)| {
3405 if collaborators.contains(&connection_id) {
3406 None
3407 } else {
3408 Some(connection_id)
3409 }
3410 });
3411
3412 broadcast(None, non_collaborators, |peer_id| {
3413 session.peer.send(
3414 peer_id,
3415 proto::UpdateChannels {
3416 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3417 channel_id: channel_id.to_proto(),
3418 epoch: epoch as u64,
3419 version: version.clone(),
3420 }],
3421 ..Default::default()
3422 },
3423 )
3424 });
3425
3426 Ok(())
3427}
3428
3429/// Rejoin the channel notes after a connection blip
3430async fn rejoin_channel_buffers(
3431 request: proto::RejoinChannelBuffers,
3432 response: Response<proto::RejoinChannelBuffers>,
3433 session: Session,
3434) -> Result<()> {
3435 let db = session.db().await;
3436 let buffers = db
3437 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3438 .await?;
3439
3440 for rejoined_buffer in &buffers {
3441 let collaborators_to_notify = rejoined_buffer
3442 .buffer
3443 .collaborators
3444 .iter()
3445 .filter_map(|c| Some(c.peer_id?.into()));
3446 channel_buffer_updated(
3447 session.connection_id,
3448 collaborators_to_notify,
3449 &proto::UpdateChannelBufferCollaborators {
3450 channel_id: rejoined_buffer.buffer.channel_id,
3451 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3452 },
3453 &session.peer,
3454 );
3455 }
3456
3457 response.send(proto::RejoinChannelBuffersResponse {
3458 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3459 })?;
3460
3461 Ok(())
3462}
3463
3464/// Stop editing the channel notes
3465async fn leave_channel_buffer(
3466 request: proto::LeaveChannelBuffer,
3467 response: Response<proto::LeaveChannelBuffer>,
3468 session: Session,
3469) -> Result<()> {
3470 let db = session.db().await;
3471 let channel_id = ChannelId::from_proto(request.channel_id);
3472
3473 let left_buffer = db
3474 .leave_channel_buffer(channel_id, session.connection_id)
3475 .await?;
3476
3477 response.send(Ack {})?;
3478
3479 channel_buffer_updated(
3480 session.connection_id,
3481 left_buffer.connections,
3482 &proto::UpdateChannelBufferCollaborators {
3483 channel_id: channel_id.to_proto(),
3484 collaborators: left_buffer.collaborators,
3485 },
3486 &session.peer,
3487 );
3488
3489 Ok(())
3490}
3491
3492fn channel_buffer_updated<T: EnvelopedMessage>(
3493 sender_id: ConnectionId,
3494 collaborators: impl IntoIterator<Item = ConnectionId>,
3495 message: &T,
3496 peer: &Peer,
3497) {
3498 broadcast(Some(sender_id), collaborators, |peer_id| {
3499 peer.send(peer_id, message.clone())
3500 });
3501}
3502
3503fn send_notifications(
3504 connection_pool: &ConnectionPool,
3505 peer: &Peer,
3506 notifications: db::NotificationBatch,
3507) {
3508 for (user_id, notification) in notifications {
3509 for connection_id in connection_pool.user_connection_ids(user_id) {
3510 if let Err(error) = peer.send(
3511 connection_id,
3512 proto::AddNotification {
3513 notification: Some(notification.clone()),
3514 },
3515 ) {
3516 tracing::error!(
3517 "failed to send notification to {:?} {}",
3518 connection_id,
3519 error
3520 );
3521 }
3522 }
3523 }
3524}
3525
/// Send a message to the channel
///
/// Steps:
/// 1. Validates and trims the body.
/// 2. Persists the message.
/// 3. Broadcasts `ChannelMessageSent` to the chat participant connections
///    returned by the database, passing this connection as the sender.
/// 4. Replies to the requester with the stored message.
/// 5. Sends every other connection subscribed to the channel (per the
///    connection pool) the channel's latest message id.
/// 6. Delivers any notifications the database produced.
///
/// # Errors
/// Fails if the trimmed body is longer than `MAX_MESSAGE_LEN` bytes or is
/// empty, if `request.nonce` is missing, or if the database call or the
/// response send fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // Note: `len()` is a byte count, not a character count.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and broadcast unchanged. If
    // mentions carry positions into the untrimmed body, leading whitespace
    // removed above would shift them — confirm the mention format.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not among the chat
    // participants only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3620
3621/// Delete a channel message
3622async fn remove_channel_message(
3623 request: proto::RemoveChannelMessage,
3624 response: Response<proto::RemoveChannelMessage>,
3625 session: Session,
3626) -> Result<()> {
3627 let channel_id = ChannelId::from_proto(request.channel_id);
3628 let message_id = MessageId::from_proto(request.message_id);
3629 let (connection_ids, existing_notification_ids) = session
3630 .db()
3631 .await
3632 .remove_channel_message(channel_id, message_id, session.user_id())
3633 .await?;
3634
3635 broadcast(
3636 Some(session.connection_id),
3637 connection_ids,
3638 move |connection| {
3639 session.peer.send(connection, request.clone())?;
3640
3641 for notification_id in &existing_notification_ids {
3642 session.peer.send(
3643 connection,
3644 proto::DeleteNotification {
3645 notification_id: (*notification_id).to_proto(),
3646 },
3647 )?;
3648 }
3649
3650 Ok(())
3651 },
3652 );
3653 response.send(proto::Ack {})?;
3654 Ok(())
3655}
3656
3657async fn update_channel_message(
3658 request: proto::UpdateChannelMessage,
3659 response: Response<proto::UpdateChannelMessage>,
3660 session: Session,
3661) -> Result<()> {
3662 let channel_id = ChannelId::from_proto(request.channel_id);
3663 let message_id = MessageId::from_proto(request.message_id);
3664 let updated_at = OffsetDateTime::now_utc();
3665 let UpdatedChannelMessage {
3666 message_id,
3667 participant_connection_ids,
3668 notifications,
3669 reply_to_message_id,
3670 timestamp,
3671 deleted_mention_notification_ids,
3672 updated_mention_notifications,
3673 } = session
3674 .db()
3675 .await
3676 .update_channel_message(
3677 channel_id,
3678 message_id,
3679 session.user_id(),
3680 request.body.as_str(),
3681 &request.mentions,
3682 updated_at,
3683 )
3684 .await?;
3685
3686 let nonce = request
3687 .nonce
3688 .clone()
3689 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3690
3691 let message = proto::ChannelMessage {
3692 sender_id: session.user_id().to_proto(),
3693 id: message_id.to_proto(),
3694 body: request.body.clone(),
3695 mentions: request.mentions.clone(),
3696 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3697 nonce: Some(nonce),
3698 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3699 edited_at: Some(updated_at.unix_timestamp() as u64),
3700 };
3701
3702 response.send(proto::Ack {})?;
3703
3704 let pool = &*session.connection_pool().await;
3705 broadcast(
3706 Some(session.connection_id),
3707 participant_connection_ids,
3708 |connection| {
3709 session.peer.send(
3710 connection,
3711 proto::ChannelMessageUpdate {
3712 channel_id: channel_id.to_proto(),
3713 message: Some(message.clone()),
3714 },
3715 )?;
3716
3717 for notification_id in &deleted_mention_notification_ids {
3718 session.peer.send(
3719 connection,
3720 proto::DeleteNotification {
3721 notification_id: (*notification_id).to_proto(),
3722 },
3723 )?;
3724 }
3725
3726 for notification in &updated_mention_notifications {
3727 session.peer.send(
3728 connection,
3729 proto::UpdateNotification {
3730 notification: Some(notification.clone()),
3731 },
3732 )?;
3733 }
3734
3735 Ok(())
3736 },
3737 );
3738
3739 send_notifications(pool, &session.peer, notifications);
3740
3741 Ok(())
3742}
3743
3744/// Mark a channel message as read
3745async fn acknowledge_channel_message(
3746 request: proto::AckChannelMessage,
3747 session: Session,
3748) -> Result<()> {
3749 let channel_id = ChannelId::from_proto(request.channel_id);
3750 let message_id = MessageId::from_proto(request.message_id);
3751 let notifications = session
3752 .db()
3753 .await
3754 .observe_channel_message(channel_id, session.user_id(), message_id)
3755 .await?;
3756 send_notifications(
3757 &*session.connection_pool().await,
3758 &session.peer,
3759 notifications,
3760 );
3761 Ok(())
3762}
3763
3764/// Mark a buffer version as synced
3765async fn acknowledge_buffer_version(
3766 request: proto::AckBufferOperation,
3767 session: Session,
3768) -> Result<()> {
3769 let buffer_id = BufferId::from_proto(request.buffer_id);
3770 session
3771 .db()
3772 .await
3773 .observe_buffer_version(
3774 buffer_id,
3775 session.user_id(),
3776 request.epoch as i32,
3777 &request.version,
3778 )
3779 .await?;
3780 Ok(())
3781}
3782
3783/// Get a Supermaven API key for the user
3784async fn get_supermaven_api_key(
3785 _request: proto::GetSupermavenApiKey,
3786 response: Response<proto::GetSupermavenApiKey>,
3787 session: Session,
3788) -> Result<()> {
3789 let user_id: String = session.user_id().to_string();
3790 if !session.is_staff() {
3791 return Err(anyhow!("supermaven not enabled for this account"))?;
3792 }
3793
3794 let email = session
3795 .email()
3796 .ok_or_else(|| anyhow!("user must have an email"))?;
3797
3798 let supermaven_admin_api = session
3799 .supermaven_client
3800 .as_ref()
3801 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3802
3803 let result = supermaven_admin_api
3804 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3805 .await?;
3806
3807 response.send(proto::GetSupermavenApiKeyResponse {
3808 api_key: result.api_key,
3809 })?;
3810
3811 Ok(())
3812}
3813
3814/// Start receiving chat updates for a channel
3815async fn join_channel_chat(
3816 request: proto::JoinChannelChat,
3817 response: Response<proto::JoinChannelChat>,
3818 session: Session,
3819) -> Result<()> {
3820 let channel_id = ChannelId::from_proto(request.channel_id);
3821
3822 let db = session.db().await;
3823 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3824 .await?;
3825 let messages = db
3826 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3827 .await?;
3828 response.send(proto::JoinChannelChatResponse {
3829 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3830 messages,
3831 })?;
3832 Ok(())
3833}
3834
3835/// Stop receiving chat updates for a channel
3836async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3837 let channel_id = ChannelId::from_proto(request.channel_id);
3838 session
3839 .db()
3840 .await
3841 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3842 .await?;
3843 Ok(())
3844}
3845
3846/// Retrieve the chat history for a channel
3847async fn get_channel_messages(
3848 request: proto::GetChannelMessages,
3849 response: Response<proto::GetChannelMessages>,
3850 session: Session,
3851) -> Result<()> {
3852 let channel_id = ChannelId::from_proto(request.channel_id);
3853 let messages = session
3854 .db()
3855 .await
3856 .get_channel_messages(
3857 channel_id,
3858 session.user_id(),
3859 MESSAGE_COUNT_PER_PAGE,
3860 Some(MessageId::from_proto(request.before_message_id)),
3861 )
3862 .await?;
3863 response.send(proto::GetChannelMessagesResponse {
3864 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3865 messages,
3866 })?;
3867 Ok(())
3868}
3869
3870/// Retrieve specific chat messages
3871async fn get_channel_messages_by_id(
3872 request: proto::GetChannelMessagesById,
3873 response: Response<proto::GetChannelMessagesById>,
3874 session: Session,
3875) -> Result<()> {
3876 let message_ids = request
3877 .message_ids
3878 .iter()
3879 .map(|id| MessageId::from_proto(*id))
3880 .collect::<Vec<_>>();
3881 let messages = session
3882 .db()
3883 .await
3884 .get_channel_messages_by_id(session.user_id(), &message_ids)
3885 .await?;
3886 response.send(proto::GetChannelMessagesResponse {
3887 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888 messages,
3889 })?;
3890 Ok(())
3891}
3892
3893/// Retrieve the current users notifications
3894async fn get_notifications(
3895 request: proto::GetNotifications,
3896 response: Response<proto::GetNotifications>,
3897 session: Session,
3898) -> Result<()> {
3899 let notifications = session
3900 .db()
3901 .await
3902 .get_notifications(
3903 session.user_id(),
3904 NOTIFICATION_COUNT_PER_PAGE,
3905 request.before_id.map(db::NotificationId::from_proto),
3906 )
3907 .await?;
3908 response.send(proto::GetNotificationsResponse {
3909 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3910 notifications,
3911 })?;
3912 Ok(())
3913}
3914
3915/// Mark notifications as read
3916async fn mark_notification_as_read(
3917 request: proto::MarkNotificationRead,
3918 response: Response<proto::MarkNotificationRead>,
3919 session: Session,
3920) -> Result<()> {
3921 let database = &session.db().await;
3922 let notifications = database
3923 .mark_notification_as_read_by_id(
3924 session.user_id(),
3925 NotificationId::from_proto(request.notification_id),
3926 )
3927 .await?;
3928 send_notifications(
3929 &*session.connection_pool().await,
3930 &session.peer,
3931 notifications,
3932 );
3933 response.send(proto::Ack {})?;
3934 Ok(())
3935}
3936
3937/// Get the current users information
3938async fn get_private_user_info(
3939 _request: proto::GetPrivateUserInfo,
3940 response: Response<proto::GetPrivateUserInfo>,
3941 session: Session,
3942) -> Result<()> {
3943 let db = session.db().await;
3944
3945 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3946 let user = db
3947 .get_user_by_id(session.user_id())
3948 .await?
3949 .ok_or_else(|| anyhow!("user not found"))?;
3950 let flags = db.get_user_flags(session.user_id()).await?;
3951
3952 response.send(proto::GetPrivateUserInfoResponse {
3953 metrics_id,
3954 staff: user.admin,
3955 flags,
3956 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3957 })?;
3958 Ok(())
3959}
3960
3961/// Accept the terms of service (tos) on behalf of the current user
3962async fn accept_terms_of_service(
3963 _request: proto::AcceptTermsOfService,
3964 response: Response<proto::AcceptTermsOfService>,
3965 session: Session,
3966) -> Result<()> {
3967 let db = session.db().await;
3968
3969 let accepted_tos_at = Utc::now();
3970 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3971 .await?;
3972
3973 response.send(proto::AcceptTermsOfServiceResponse {
3974 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3975 })?;
3976 Ok(())
3977}
3978
/// The minimum account age an account must have in order to use the LLM service.
///
/// NOTE(review): `get_llm_api_token` below does not reference this constant;
/// presumably it is enforced elsewhere — confirm before changing its value.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3981
3982async fn get_llm_api_token(
3983 _request: proto::GetLlmToken,
3984 response: Response<proto::GetLlmToken>,
3985 session: Session,
3986) -> Result<()> {
3987 let db = session.db().await;
3988
3989 let flags = db.get_user_flags(session.user_id()).await?;
3990
3991 let user_id = session.user_id();
3992 let user = db
3993 .get_user_by_id(user_id)
3994 .await?
3995 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
3996
3997 if user.accepted_tos_at.is_none() {
3998 Err(anyhow!("terms of service not accepted"))?
3999 }
4000
4001 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4002 let billing_preferences = db.get_billing_preferences(user.id).await?;
4003
4004 let token = LlmTokenClaims::create(
4005 &user,
4006 session.is_staff(),
4007 billing_preferences,
4008 &flags,
4009 billing_subscription,
4010 session.system_id.clone(),
4011 &session.app_state.config,
4012 )?;
4013 response.send(proto::GetLlmTokenResponse { token })?;
4014 Ok(())
4015}
4016
4017fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4018 let message = match message {
4019 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4020 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4021 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4022 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4023 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4024 code: frame.code.into(),
4025 reason: frame.reason.as_str().to_owned().into(),
4026 })),
4027 // We should never receive a frame while reading the message, according
4028 // to the `tungstenite` maintainers:
4029 //
4030 // > It cannot occur when you read messages from the WebSocket, but it
4031 // > can be used when you want to send the raw frames (e.g. you want to
4032 // > send the frames to the WebSocket without composing the full message first).
4033 // >
4034 // > — https://github.com/snapview/tungstenite-rs/issues/268
4035 TungsteniteMessage::Frame(_) => {
4036 bail!("received an unexpected frame while reading the message")
4037 }
4038 };
4039
4040 Ok(message)
4041}
4042
4043fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4044 match message {
4045 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4046 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4047 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4048 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4049 AxumMessage::Close(frame) => {
4050 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4051 code: frame.code.into(),
4052 reason: frame.reason.as_ref().into(),
4053 }))
4054 }
4055 }
4056}
4057
4058fn notify_membership_updated(
4059 connection_pool: &mut ConnectionPool,
4060 result: MembershipUpdated,
4061 user_id: UserId,
4062 peer: &Peer,
4063) {
4064 for membership in &result.new_channels.channel_memberships {
4065 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4066 }
4067 for channel_id in &result.removed_channels {
4068 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4069 }
4070
4071 let user_channels_update = proto::UpdateUserChannels {
4072 channel_memberships: result
4073 .new_channels
4074 .channel_memberships
4075 .iter()
4076 .map(|cm| proto::ChannelMembership {
4077 channel_id: cm.channel_id.to_proto(),
4078 role: cm.role.into(),
4079 })
4080 .collect(),
4081 ..Default::default()
4082 };
4083
4084 let mut update = build_channels_update(result.new_channels);
4085 update.delete_channels = result
4086 .removed_channels
4087 .into_iter()
4088 .map(|id| id.to_proto())
4089 .collect();
4090 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4091
4092 for connection_id in connection_pool.user_connection_ids(user_id) {
4093 peer.send(connection_id, user_channels_update.clone())
4094 .trace_err();
4095 peer.send(connection_id, update.clone()).trace_err();
4096 }
4097}
4098
4099fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4100 proto::UpdateUserChannels {
4101 channel_memberships: channels
4102 .channel_memberships
4103 .iter()
4104 .map(|m| proto::ChannelMembership {
4105 channel_id: m.channel_id.to_proto(),
4106 role: m.role.into(),
4107 })
4108 .collect(),
4109 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4110 observed_channel_message_id: channels.observed_channel_messages.clone(),
4111 }
4112}
4113
4114fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4115 let mut update = proto::UpdateChannels::default();
4116
4117 for channel in channels.channels {
4118 update.channels.push(channel.to_proto());
4119 }
4120
4121 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4122 update.latest_channel_message_ids = channels.latest_channel_messages;
4123
4124 for (channel_id, participants) in channels.channel_participants {
4125 update
4126 .channel_participants
4127 .push(proto::ChannelParticipants {
4128 channel_id: channel_id.to_proto(),
4129 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4130 });
4131 }
4132
4133 for channel in channels.invited_channels {
4134 update.channel_invitations.push(channel.to_proto());
4135 }
4136
4137 update
4138}
4139
4140fn build_initial_contacts_update(
4141 contacts: Vec<db::Contact>,
4142 pool: &ConnectionPool,
4143) -> proto::UpdateContacts {
4144 let mut update = proto::UpdateContacts::default();
4145
4146 for contact in contacts {
4147 match contact {
4148 db::Contact::Accepted { user_id, busy } => {
4149 update.contacts.push(contact_for_user(user_id, busy, pool));
4150 }
4151 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4152 db::Contact::Incoming { user_id } => {
4153 update
4154 .incoming_requests
4155 .push(proto::IncomingContactRequest {
4156 requester_id: user_id.to_proto(),
4157 })
4158 }
4159 }
4160 }
4161
4162 update
4163}
4164
4165fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4166 proto::Contact {
4167 user_id: user_id.to_proto(),
4168 online: pool.is_user_online(user_id),
4169 busy,
4170 }
4171}
4172
4173fn room_updated(room: &proto::Room, peer: &Peer) {
4174 broadcast(
4175 None,
4176 room.participants
4177 .iter()
4178 .filter_map(|participant| Some(participant.peer_id?.into())),
4179 |peer_id| {
4180 peer.send(
4181 peer_id,
4182 proto::RoomUpdated {
4183 room: Some(room.clone()),
4184 },
4185 )
4186 },
4187 );
4188}
4189
4190fn channel_updated(
4191 channel: &db::channel::Model,
4192 room: &proto::Room,
4193 peer: &Peer,
4194 pool: &ConnectionPool,
4195) {
4196 let participants = room
4197 .participants
4198 .iter()
4199 .map(|p| p.user_id)
4200 .collect::<Vec<_>>();
4201
4202 broadcast(
4203 None,
4204 pool.channel_connection_ids(channel.root_id())
4205 .filter_map(|(channel_id, role)| {
4206 role.can_see_channel(channel.visibility)
4207 .then_some(channel_id)
4208 }),
4209 |peer_id| {
4210 peer.send(
4211 peer_id,
4212 proto::UpdateChannels {
4213 channel_participants: vec![proto::ChannelParticipants {
4214 channel_id: channel.id.to_proto(),
4215 participant_user_ids: participants.clone(),
4216 }],
4217 ..Default::default()
4218 },
4219 )
4220 },
4221 );
4222}
4223
4224async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4225 let db = session.db().await;
4226
4227 let contacts = db.get_contacts(user_id).await?;
4228 let busy = db.is_user_busy(user_id).await?;
4229
4230 let pool = session.connection_pool().await;
4231 let updated_contact = contact_for_user(user_id, busy, &pool);
4232 for contact in contacts {
4233 if let db::Contact::Accepted {
4234 user_id: contact_user_id,
4235 ..
4236 } = contact
4237 {
4238 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4239 session
4240 .peer
4241 .send(
4242 contact_conn_id,
4243 proto::UpdateContacts {
4244 contacts: vec![updated_contact.clone()],
4245 remove_contacts: Default::default(),
4246 incoming_requests: Default::default(),
4247 remove_incoming_requests: Default::default(),
4248 outgoing_requests: Default::default(),
4249 remove_outgoing_requests: Default::default(),
4250 },
4251 )
4252 .trace_err();
4253 }
4254 }
4255 }
4256 Ok(())
4257}
4258
/// Remove `connection_id` from whatever room it is in and notify everyone
/// affected.
///
/// Returns `Ok(())` immediately if the database reports the connection was not
/// in a room. Otherwise it:
/// 1. Handles the projects the connection left.
/// 2. Broadcasts the updated room and the channel's participants.
/// 3. Cancels pending calls.
/// 4. Refreshes the contact status of this user and of every callee.
/// 5. Removes the user from the LiveKit room, deleting that room when the
///    database reports it was deleted.
///
/// LiveKit failures are passed to `trace_err` and do not fail the call.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialised; every path that reaches their uses assigns them
    // inside the `if let` below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard is a temporary in this `if let`
    // scrutinee, so per Rust's temporary-scope rules it is presumably still
    // held while `project_left` and `room_updated` run — confirm before
    // restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out before `room` is, so the `room` broadcast
        // below carries a default (empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires its own guards.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4332
4333async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4334 let left_channel_buffers = session
4335 .db()
4336 .await
4337 .leave_channel_buffers(session.connection_id)
4338 .await?;
4339
4340 for left_buffer in left_channel_buffers {
4341 channel_buffer_updated(
4342 session.connection_id,
4343 left_buffer.connections,
4344 &proto::UpdateChannelBufferCollaborators {
4345 channel_id: left_buffer.channel_id.to_proto(),
4346 collaborators: left_buffer.collaborators,
4347 },
4348 &session.peer,
4349 );
4350 }
4351
4352 Ok(())
4353}
4354
4355fn project_left(project: &db::LeftProject, session: &Session) {
4356 for connection_id in &project.connection_ids {
4357 if project.should_unshare {
4358 session
4359 .peer
4360 .send(
4361 *connection_id,
4362 proto::UnshareProject {
4363 project_id: project.id.to_proto(),
4364 },
4365 )
4366 .trace_err();
4367 } else {
4368 session
4369 .peer
4370 .send(
4371 *connection_id,
4372 proto::RemoveProjectCollaborator {
4373 project_id: project.id.to_proto(),
4374 peer_id: Some(session.connection_id.into()),
4375 },
4376 )
4377 .trace_err();
4378 }
4379 }
4380}
4381
/// Extension for turning a fallible result into an `Option`, reporting the
/// error instead of propagating it.
///
/// Used for best-effort operations (e.g. sending a message to a peer) where a
/// failure should be recorded but must not abort the surrounding work.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Consumes `self` and returns the success value. On failure, the error
    /// is reported and `None` is returned. For `Result<T, E>` in this module,
    /// the error is reported through `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4387
4388impl<T, E> ResultExt for Result<T, E>
4389where
4390 E: std::fmt::Debug,
4391{
4392 type Ok = T;
4393
4394 #[track_caller]
4395 fn trace_err(self) -> Option<T> {
4396 match self {
4397 Ok(value) => Some(value),
4398 Err(error) => {
4399 tracing::error!("{:?}", error);
4400 None
4401 }
4402 }
4403 }
4404}