1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
6use crate::{
7 AppState, Error, Result, auth,
8 db::{
9 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
10 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
11 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
12 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15};
16use anyhow::{Context as _, anyhow, bail};
17use async_tungstenite::tungstenite::{
18 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
19};
20use axum::{
21 Extension, Router, TypedHeader,
22 body::Body,
23 extract::{
24 ConnectInfo, WebSocketUpgrade,
25 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40use util::maybe;
41
42use futures::{
43 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
44 stream::FuturesUnordered,
45};
46use prometheus::{IntGauge, register_int_gauge};
47use rpc::{
48 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 Arc, OnceLock,
66 atomic::{AtomicBool, Ordering::SeqCst},
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{MutexGuard, Semaphore, watch};
72use tower::ServiceBuilder;
73use tracing::{
74 Instrument,
75 field::{self},
76 info_span, instrument,
77};
78
/// Reconnection timeout (30 seconds).
///
/// NOTE(review): the consumer of this constant is outside this part of the
/// file; presumably it bounds how long a dropped client may take to reconnect
/// before its state is released — confirm against the usage site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before clearing stale rooms and channel
// buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. NOTE(review): these are consumed by handlers outside this section;
// the names suggest page sizes for channel messages / notifications and a cap on chat message
// length — confirm units (bytes vs. characters) at the usage sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` of the
/// connection the message arrived on, and returns a boxed future resolving to
/// `()`. `Server::add_handler` builds these by downcasting the envelope to its
/// concrete type and logging the typed handler's result inside the future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// Reply handle given to request handlers registered through
/// `Server::add_request_handler`.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `add_request_handler` checks it to detect
    /// handlers that return `Ok` without ever replying.
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user connected as themselves.
    User(User),
    /// `admin` is acting as `user`. Sessions with this principal are treated
    /// as staff (see `Session::is_staff`), while `Session::user_id` and
    /// `Session::email` resolve to `user`.
    Impersonated { user: User, admin: User },
}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection context passed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    /// Identifier assigned when the connection was added to the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and replies.
    peer: Arc<Peer>,
    /// The server-wide connection pool; access it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no `supermaven_admin_api_key` is
    /// configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was established, if any.
    system_id: Option<String>,
    /// Clone of the executor the connection is being handled on.
    _executor: Executor,
}
142
143impl Session {
144 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
145 #[cfg(test)]
146 tokio::task::yield_now().await;
147 let guard = self.db.lock().await;
148 #[cfg(test)]
149 tokio::task::yield_now().await;
150 guard
151 }
152
153 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.connection_pool.lock();
157 ConnectionPoolGuard {
158 guard,
159 _not_send: PhantomData,
160 }
161 }
162
163 fn is_staff(&self) -> bool {
164 match &self.principal {
165 Principal::User(user) => user.admin,
166 Principal::Impersonated { .. } => true,
167 }
168 }
169
170 pub async fn has_llm_subscription(
171 &self,
172 db: &MutexGuard<'_, DbHandle>,
173 ) -> anyhow::Result<bool> {
174 if self.is_staff() {
175 return Ok(true);
176 }
177
178 let user_id = self.user_id();
179
180 Ok(db.has_active_billing_subscription(user_id).await?)
181 }
182
183 pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
184 if self.is_staff() {
185 return Ok(proto::Plan::ZedPro);
186 }
187
188 let user_id = self.user_id();
189
190 let subscription = db.get_active_billing_subscription(user_id).await?;
191 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
192
193 let plan = if let Some(subscription_kind) = subscription_kind {
194 match subscription_kind {
195 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
196 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
197 SubscriptionKind::ZedFree => proto::Plan::Free,
198 }
199 } else {
200 proto::Plan::Free
201 };
202
203 Ok(plan)
204 }
205
206 fn user_id(&self) -> UserId {
207 match &self.principal {
208 Principal::User(user) => user.id,
209 Principal::Impersonated { user, .. } => user.id,
210 }
211 }
212
213 pub fn email(&self) -> Option<String> {
214 match &self.principal {
215 Principal::User(user) => user.email_address.clone(),
216 Principal::Impersonated { user, .. } => user.email_address.clone(),
217 }
218 }
219}
220
221impl Debug for Session {
222 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
223 let mut result = f.debug_struct("Session");
224 match &self.principal {
225 Principal::User(user) => {
226 result.field("user", &user.github_login);
227 }
228 Principal::Impersonated { user, admin } => {
229 result.field("user", &user.github_login);
230 result.field("impersonator", &admin.github_login);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct DbHandle(Arc<Database>);
238
239impl Deref for DbHandle {
240 type Target = Database;
241
242 fn deref(&self) -> &Self::Target {
243 self.0.as_ref()
244 }
245}
246
/// The RPC server: owns the peer, the connection pool and the table of
/// message handlers.
pub struct Server {
    /// Current server id; replaced by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer through which all messages are sent and connections are added.
    peer: Arc<Peer>,
    /// Connections currently known to this server.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` publishes `true`; connection tasks
    /// subscribe and stop when the value changes.
    teardown: watch::Sender<bool>,
}
255
/// A held lock on the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
    // NOTE(review): presumably this is to stop the guard being held across an
    // `.await` in a task that must be `Send` — confirm the intent.
    _not_send: PhantomData<Rc<()>>,
}
260
/// Serializable view of the server's peer and connection pool. It holds the
/// pool lock for as long as it lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard
    // dereferences to rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
267
268pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
269where
270 S: Serializer,
271 T: Deref<Target = U>,
272 U: Serialize,
273{
274 Serialize::serialize(value.deref(), serializer)
275}
276
277impl Server {
/// Creates a server with the given id and registers the handler for every
/// message type it serves.
///
/// The peer is created from the numeric server id (`id.0 as u32`), the
/// connection pool starts empty, and the teardown signal starts as `false`.
///
/// # Panics
///
/// Panics if two handlers are registered for the same message type (see
/// `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        teardown: watch::channel(false).0,
    };

    // `add_request_handler` is for messages that expect a reply;
    // `add_message_handler` is for fire-and-forget messages.
    //
    // NOTE(review): the `forward_read_only_project_request` /
    // `forward_mutating_project_request` helpers are defined outside this
    // section; presumably they differ in the capability required of the
    // sender. A few classifications below look surprising and are worth
    // confirming: `GitReset` and `GitCheckoutFiles` are registered as
    // read-only, while `GetCodeLens`, `GetCompletions`, `GetCodeActions` and
    // `BlameBuffer` are registered as mutating.
    server
        // Rooms and calls.
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        // Project sharing and project state updates.
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_request_handler(update_repository)
        .add_request_handler(remove_repository)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Requests forwarded to a project.
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
        .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
        .add_request_handler(
            forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
        )
        .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
        .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
        .add_request_handler(
            forward_read_only_project_request::<proto::LanguageServerIdForName>,
        )
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        // Buffer traffic and host-originated broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        // Channels, channel buffers, chat and notifications.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        // Following.
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        // Account, LLM and Supermaven.
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        // Assistant contexts, git operations and breakpoints.
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
        .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
        .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
        .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
        .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
        .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context);

    Arc::new(server)
}
448
449 pub async fn start(&self) -> Result<()> {
450 let server_id = *self.id.lock();
451 let app_state = self.app_state.clone();
452 let peer = self.peer.clone();
453 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
454 let pool = self.connection_pool.clone();
455 let livekit_client = self.app_state.livekit_client.clone();
456
457 let span = info_span!("start server");
458 self.app_state.executor.spawn_detached(
459 async move {
460 tracing::info!("waiting for cleanup timeout");
461 timeout.await;
462 tracing::info!("cleanup timeout expired, retrieving stale rooms");
463 if let Some((room_ids, channel_ids)) = app_state
464 .db
465 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
466 .await
467 .trace_err()
468 {
469 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
470 tracing::info!(
471 stale_channel_buffer_count = channel_ids.len(),
472 "retrieved stale channel buffers"
473 );
474
475 for channel_id in channel_ids {
476 if let Some(refreshed_channel_buffer) = app_state
477 .db
478 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
479 .await
480 .trace_err()
481 {
482 for connection_id in refreshed_channel_buffer.connection_ids {
483 peer.send(
484 connection_id,
485 proto::UpdateChannelBufferCollaborators {
486 channel_id: channel_id.to_proto(),
487 collaborators: refreshed_channel_buffer
488 .collaborators
489 .clone(),
490 },
491 )
492 .trace_err();
493 }
494 }
495 }
496
497 for room_id in room_ids {
498 let mut contacts_to_update = HashSet::default();
499 let mut canceled_calls_to_user_ids = Vec::new();
500 let mut livekit_room = String::new();
501 let mut delete_livekit_room = false;
502
503 if let Some(mut refreshed_room) = app_state
504 .db
505 .clear_stale_room_participants(room_id, server_id)
506 .await
507 .trace_err()
508 {
509 tracing::info!(
510 room_id = room_id.0,
511 new_participant_count = refreshed_room.room.participants.len(),
512 "refreshed room"
513 );
514 room_updated(&refreshed_room.room, &peer);
515 if let Some(channel) = refreshed_room.channel.as_ref() {
516 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
517 }
518 contacts_to_update
519 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
520 contacts_to_update
521 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
522 canceled_calls_to_user_ids =
523 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
524 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
525 delete_livekit_room = refreshed_room.room.participants.is_empty();
526 }
527
528 {
529 let pool = pool.lock();
530 for canceled_user_id in canceled_calls_to_user_ids {
531 for connection_id in pool.user_connection_ids(canceled_user_id) {
532 peer.send(
533 connection_id,
534 proto::CallCanceled {
535 room_id: room_id.to_proto(),
536 },
537 )
538 .trace_err();
539 }
540 }
541 }
542
543 for user_id in contacts_to_update {
544 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
545 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
546 if let Some((busy, contacts)) = busy.zip(contacts) {
547 let pool = pool.lock();
548 let updated_contact = contact_for_user(user_id, busy, &pool);
549 for contact in contacts {
550 if let db::Contact::Accepted {
551 user_id: contact_user_id,
552 ..
553 } = contact
554 {
555 for contact_conn_id in
556 pool.user_connection_ids(contact_user_id)
557 {
558 peer.send(
559 contact_conn_id,
560 proto::UpdateContacts {
561 contacts: vec![updated_contact.clone()],
562 remove_contacts: Default::default(),
563 incoming_requests: Default::default(),
564 remove_incoming_requests: Default::default(),
565 outgoing_requests: Default::default(),
566 remove_outgoing_requests: Default::default(),
567 },
568 )
569 .trace_err();
570 }
571 }
572 }
573 }
574 }
575
576 if let Some(live_kit) = livekit_client.as_ref() {
577 if delete_livekit_room {
578 live_kit.delete_room(livekit_room).await.trace_err();
579 }
580 }
581 }
582 }
583
584 app_state
585 .db
586 .delete_stale_servers(&app_state.config.zed_environment, server_id)
587 .await
588 .trace_err();
589 }
590 .instrument(span),
591 );
592 Ok(())
593 }
594
595 pub fn teardown(&self) {
596 self.peer.teardown();
597 self.connection_pool.lock().reset();
598 let _ = self.teardown.send(true);
599 }
600
/// Test-only: tears the server down and brings it back under a new id,
/// resetting the peer and clearing the teardown signal.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
    let _ = self.teardown.send(false);
}
608
/// Test-only accessor for the current server id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = self.id.lock();
    *current_id
}
613
614 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
615 where
616 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
617 Fut: 'static + Send + Future<Output = Result<()>>,
618 M: EnvelopedMessage,
619 {
620 let prev_handler = self.handlers.insert(
621 TypeId::of::<M>(),
622 Box::new(move |envelope, session| {
623 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
624 let received_at = envelope.received_at;
625 tracing::info!("message received");
626 let start_time = Instant::now();
627 let future = (handler)(*envelope, session);
628 async move {
629 let result = future.await;
630 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
631 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
632 let queue_duration_ms = total_duration_ms - processing_duration_ms;
633 let payload_type = M::NAME;
634
635 match result {
636 Err(error) => {
637 tracing::error!(
638 ?error,
639 total_duration_ms,
640 processing_duration_ms,
641 queue_duration_ms,
642 payload_type,
643 "error handling message"
644 )
645 }
646 Ok(()) => tracing::info!(
647 total_duration_ms,
648 processing_duration_ms,
649 queue_duration_ms,
650 "finished handling message"
651 ),
652 }
653 }
654 .boxed()
655 }),
656 );
657 if prev_handler.is_some() {
658 panic!("registered a handler for the same message twice");
659 }
660 self
661 }
662
663 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
664 where
665 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
666 Fut: 'static + Send + Future<Output = Result<()>>,
667 M: EnvelopedMessage,
668 {
669 self.add_handler(move |envelope, session| handler(envelope.payload, session));
670 self
671 }
672
673 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
674 where
675 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
676 Fut: Send + Future<Output = Result<()>>,
677 M: RequestMessage,
678 {
679 let handler = Arc::new(handler);
680 self.add_handler(move |envelope, session| {
681 let receipt = envelope.receipt();
682 let handler = handler.clone();
683 async move {
684 let peer = session.peer.clone();
685 let responded = Arc::new(AtomicBool::default());
686 let response = Response {
687 peer: peer.clone(),
688 responded: responded.clone(),
689 receipt,
690 };
691 match (handler)(envelope.payload, response, session).await {
692 Ok(()) => {
693 if responded.load(std::sync::atomic::Ordering::SeqCst) {
694 Ok(())
695 } else {
696 Err(anyhow!("handler did not send a response"))?
697 }
698 }
699 Err(error) => {
700 let proto_err = match &error {
701 Error::Internal(err) => err.to_proto(),
702 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
703 };
704 peer.respond_with_error(receipt, proto_err)?;
705 Err(error)
706 }
707 }
708 }
709 })
710 }
711
/// Returns a future that services one client connection until it closes or
/// the server tears down.
///
/// The future registers the connection with the peer, builds the
/// connection's `Session`, sends the initial client update, then loops
/// dispatching incoming messages to the registered handlers. When the loop
/// ends because I/O finished or the incoming stream closed, the session is
/// signed out through `connection_lost`.
///
/// Everything is traced under a "handle connection" span carrying the
/// address, connection id, principal fields and GeoIP country code.
///
/// NOTE(review): the early `return`s after `add_connection` (HTTP client
/// creation failure, initial client update failure) and the teardown branch
/// of the loop all skip `connection_lost`. For the teardown branch that is
/// presumably intentional, since `teardown` resets the pool; for the other
/// two, confirm that the connection does not stay registered with the peer
/// (and, after a partial initial update, with the connection pool).
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> + use<> {
    let this = self.clone();
    // Fields start empty and are filled in as the values become known.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        // Refuse new connections while the server is tearing down.
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only built when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // A permit is taken before each message is read and released when its handler
        // finishes, so at most 256 handlers are in flight for this connection.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` gives priority to the branches in the order listed: teardown,
            // then I/O completion, then finished foreground handlers, then new messages.
            futures::select_biased! {
                // Returns immediately, skipping the sign-out after the loop.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            // The handler is invoked (but not yet awaited) inside the span.
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run detached on the executor; foreground
                            // messages are driven by this loop.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Pending foreground handlers are dropped unfinished before signing out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
859
860 async fn send_initial_client_update(
861 &self,
862 connection_id: ConnectionId,
863 principal: &Principal,
864 zed_version: ZedVersion,
865 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
866 session: &Session,
867 ) -> Result<()> {
868 self.peer.send(
869 connection_id,
870 proto::Hello {
871 peer_id: Some(connection_id.into()),
872 },
873 )?;
874 tracing::info!("sent hello message");
875 if let Some(send_connection_id) = send_connection_id.take() {
876 let _ = send_connection_id.send(connection_id);
877 }
878
879 match principal {
880 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
881 if !user.connected_once {
882 self.peer.send(connection_id, proto::ShowContacts {})?;
883 self.app_state
884 .db
885 .set_user_connected_once(user.id, true)
886 .await?;
887 }
888
889 update_user_plan(user.id, session).await?;
890
891 let contacts = self.app_state.db.get_contacts(user.id).await?;
892
893 {
894 let mut pool = self.connection_pool.lock();
895 pool.add_connection(connection_id, user.id, user.admin, zed_version);
896 self.peer.send(
897 connection_id,
898 build_initial_contacts_update(contacts, &pool),
899 )?;
900 }
901
902 if should_auto_subscribe_to_channels(zed_version) {
903 subscribe_user_to_channels(user.id, session).await?;
904 }
905
906 if let Some(incoming_call) =
907 self.app_state.db.incoming_call_for_user(user.id).await?
908 {
909 self.peer.send(connection_id, incoming_call)?;
910 }
911
912 update_user_contacts(user.id, session).await?;
913 }
914 }
915
916 Ok(())
917 }
918
919 pub async fn invite_code_redeemed(
920 self: &Arc<Self>,
921 inviter_id: UserId,
922 invitee_id: UserId,
923 ) -> Result<()> {
924 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
925 if let Some(code) = &user.invite_code {
926 let pool = self.connection_pool.lock();
927 let invitee_contact = contact_for_user(invitee_id, false, &pool);
928 for connection_id in pool.user_connection_ids(inviter_id) {
929 self.peer.send(
930 connection_id,
931 proto::UpdateContacts {
932 contacts: vec![invitee_contact.clone()],
933 ..Default::default()
934 },
935 )?;
936 self.peer.send(
937 connection_id,
938 proto::UpdateInviteInfo {
939 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
940 count: user.invite_count as u32,
941 },
942 )?;
943 }
944 }
945 }
946 Ok(())
947 }
948
949 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
950 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
951 if let Some(invite_code) = &user.invite_code {
952 let pool = self.connection_pool.lock();
953 for connection_id in pool.user_connection_ids(user_id) {
954 self.peer.send(
955 connection_id,
956 proto::UpdateInviteInfo {
957 url: format!(
958 "{}{}",
959 self.app_state.config.invite_link_prefix, invite_code
960 ),
961 count: user.invite_count as u32,
962 },
963 )?;
964 }
965 }
966 }
967 Ok(())
968 }
969
970 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
971 let pool = self.connection_pool.lock();
972 for connection_id in pool.user_connection_ids(user_id) {
973 self.peer
974 .send(connection_id, proto::RefreshLlmToken {})
975 .trace_err();
976 }
977 }
978
979 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
980 ServerSnapshot {
981 connection_pool: ConnectionPoolGuard {
982 guard: self.connection_pool.lock(),
983 _not_send: PhantomData,
984 },
985 peer: &self.peer,
986 }
987 }
988}
989
990impl Deref for ConnectionPoolGuard<'_> {
991 type Target = ConnectionPool;
992
993 fn deref(&self) -> &Self::Target {
994 &self.guard
995 }
996}
997
998impl DerefMut for ConnectionPoolGuard<'_> {
999 fn deref_mut(&mut self) -> &mut Self::Target {
1000 &mut self.guard
1001 }
1002}
1003
1004impl Drop for ConnectionPoolGuard<'_> {
1005 fn drop(&mut self) {
1006 #[cfg(test)]
1007 self.check_invariants();
1008 }
1009}
1010
1011fn broadcast<F>(
1012 sender_id: Option<ConnectionId>,
1013 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1014 mut f: F,
1015) where
1016 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1017{
1018 for receiver_id in receiver_ids {
1019 if Some(receiver_id) != sender_id {
1020 if let Err(error) = f(receiver_id) {
1021 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1022 }
1023 }
1024 }
1025}
1026
1027pub struct ProtocolVersion(u32);
1028
1029impl Header for ProtocolVersion {
1030 fn name() -> &'static HeaderName {
1031 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1033 }
1034
1035 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036 where
1037 Self: Sized,
1038 I: Iterator<Item = &'i axum::http::HeaderValue>,
1039 {
1040 let version = values
1041 .next()
1042 .ok_or_else(axum::headers::Error::invalid)?
1043 .to_str()
1044 .map_err(|_| axum::headers::Error::invalid())?
1045 .parse()
1046 .map_err(|_| axum::headers::Error::invalid())?;
1047 Ok(Self(version))
1048 }
1049
1050 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051 values.extend([self.0.to_string().parse().unwrap()]);
1052 }
1053}
1054
1055pub struct AppVersionHeader(SemanticVersion);
1056impl Header for AppVersionHeader {
1057 fn name() -> &'static HeaderName {
1058 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1059 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1060 }
1061
1062 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1063 where
1064 Self: Sized,
1065 I: Iterator<Item = &'i axum::http::HeaderValue>,
1066 {
1067 let version = values
1068 .next()
1069 .ok_or_else(axum::headers::Error::invalid)?
1070 .to_str()
1071 .map_err(|_| axum::headers::Error::invalid())?
1072 .parse()
1073 .map_err(|_| axum::headers::Error::invalid())?;
1074 Ok(Self(version))
1075 }
1076
1077 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1078 values.extend([self.0.to_string().parse().unwrap()]);
1079 }
1080}
1081
1082pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1083 Router::new()
1084 .route("/rpc", get(handle_websocket_request))
1085 .layer(
1086 ServiceBuilder::new()
1087 .layer(Extension(server.app_state.clone()))
1088 .layer(middleware::from_fn(auth::validate_header)),
1089 )
1090 .route("/metrics", get(handle_metrics))
1091 .layer(Extension(server))
1092}
1093
1094pub async fn handle_websocket_request(
1095 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1096 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1097 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1098 Extension(server): Extension<Arc<Server>>,
1099 Extension(principal): Extension<Principal>,
1100 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1101 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1102 ws: WebSocketUpgrade,
1103) -> axum::response::Response {
1104 if protocol_version != rpc::PROTOCOL_VERSION {
1105 return (
1106 StatusCode::UPGRADE_REQUIRED,
1107 "client must be upgraded".to_string(),
1108 )
1109 .into_response();
1110 }
1111
1112 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1113 return (
1114 StatusCode::UPGRADE_REQUIRED,
1115 "no version header found".to_string(),
1116 )
1117 .into_response();
1118 };
1119
1120 if !version.can_collaborate() {
1121 return (
1122 StatusCode::UPGRADE_REQUIRED,
1123 "client must be upgraded".to_string(),
1124 )
1125 .into_response();
1126 }
1127
1128 let socket_address = socket_address.to_string();
1129 ws.on_upgrade(move |socket| {
1130 let socket = socket
1131 .map_ok(to_tungstenite_message)
1132 .err_into()
1133 .with(|message| async move { to_axum_message(message) });
1134 let connection = Connection::new(Box::pin(socket));
1135 async move {
1136 server
1137 .handle_connection(
1138 connection,
1139 socket_address,
1140 principal,
1141 version,
1142 country_code_header.map(|header| header.to_string()),
1143 system_id_header.map(|header| header.to_string()),
1144 None,
1145 Executor::Production,
1146 )
1147 .await;
1148 }
1149 })
1150}
1151
1152pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1153 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1154 let connections_metric = CONNECTIONS_METRIC
1155 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1156
1157 let connections = server
1158 .connection_pool
1159 .lock()
1160 .connections()
1161 .filter(|connection| !connection.admin)
1162 .count();
1163 connections_metric.set(connections as _);
1164
1165 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1166 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1167 register_int_gauge!(
1168 "shared_projects",
1169 "number of open projects with one or more guests"
1170 )
1171 .unwrap()
1172 });
1173
1174 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1175 shared_projects_metric.set(shared_projects as _);
1176
1177 let encoder = prometheus::TextEncoder::new();
1178 let metric_families = prometheus::gather();
1179 let encoded_metrics = encoder
1180 .encode_to_string(&metric_families)
1181 .map_err(|err| anyhow!("{}", err))?;
1182 Ok(encoded_metrics)
1183}
1184
1185#[instrument(err, skip(executor))]
1186async fn connection_lost(
1187 session: Session,
1188 mut teardown: watch::Receiver<bool>,
1189 executor: Executor,
1190) -> Result<()> {
1191 session.peer.disconnect(session.connection_id);
1192 session
1193 .connection_pool()
1194 .await
1195 .remove_connection(session.connection_id)?;
1196
1197 session
1198 .db()
1199 .await
1200 .connection_lost(session.connection_id)
1201 .await
1202 .trace_err();
1203
1204 futures::select_biased! {
1205 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1206
1207 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1208 leave_room_for_session(&session, session.connection_id).await.trace_err();
1209 leave_channel_buffers_for_session(&session)
1210 .await
1211 .trace_err();
1212
1213 if !session
1214 .connection_pool()
1215 .await
1216 .is_user_online(session.user_id())
1217 {
1218 let db = session.db().await;
1219 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1220 room_updated(&room, &session.peer);
1221 }
1222 }
1223
1224 update_user_contacts(session.user_id(), &session).await?;
1225 },
1226 _ = teardown.changed().fuse() => {}
1227 }
1228
1229 Ok(())
1230}
1231
1232/// Acknowledges a ping from a client, used to keep the connection alive.
1233async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1234 response.send(proto::Ack {})?;
1235 Ok(())
1236}
1237
1238/// Creates a new room for calling (outside of channels)
1239async fn create_room(
1240 _request: proto::CreateRoom,
1241 response: Response<proto::CreateRoom>,
1242 session: Session,
1243) -> Result<()> {
1244 let livekit_room = nanoid::nanoid!(30);
1245
1246 let live_kit_connection_info = util::maybe!(async {
1247 let live_kit = session.app_state.livekit_client.as_ref();
1248 let live_kit = live_kit?;
1249 let user_id = session.user_id().to_string();
1250
1251 let token = live_kit
1252 .room_token(&livekit_room, &user_id.to_string())
1253 .trace_err()?;
1254
1255 Some(proto::LiveKitConnectionInfo {
1256 server_url: live_kit.url().into(),
1257 token,
1258 can_publish: true,
1259 })
1260 })
1261 .await;
1262
1263 let room = session
1264 .db()
1265 .await
1266 .create_room(session.user_id(), session.connection_id, &livekit_room)
1267 .await?;
1268
1269 response.send(proto::CreateRoomResponse {
1270 room: Some(room.clone()),
1271 live_kit_connection_info,
1272 })?;
1273
1274 update_user_contacts(session.user_id(), &session).await?;
1275 Ok(())
1276}
1277
1278/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1279async fn join_room(
1280 request: proto::JoinRoom,
1281 response: Response<proto::JoinRoom>,
1282 session: Session,
1283) -> Result<()> {
1284 let room_id = RoomId::from_proto(request.id);
1285
1286 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1287
1288 if let Some(channel_id) = channel_id {
1289 return join_channel_internal(channel_id, Box::new(response), session).await;
1290 }
1291
1292 let joined_room = {
1293 let room = session
1294 .db()
1295 .await
1296 .join_room(room_id, session.user_id(), session.connection_id)
1297 .await?;
1298 room_updated(&room.room, &session.peer);
1299 room.into_inner()
1300 };
1301
1302 for connection_id in session
1303 .connection_pool()
1304 .await
1305 .user_connection_ids(session.user_id())
1306 {
1307 session
1308 .peer
1309 .send(
1310 connection_id,
1311 proto::CallCanceled {
1312 room_id: room_id.to_proto(),
1313 },
1314 )
1315 .trace_err();
1316 }
1317
1318 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1319 {
1320 live_kit
1321 .room_token(
1322 &joined_room.room.livekit_room,
1323 &session.user_id().to_string(),
1324 )
1325 .trace_err()
1326 .map(|token| proto::LiveKitConnectionInfo {
1327 server_url: live_kit.url().into(),
1328 token,
1329 can_publish: true,
1330 })
1331 } else {
1332 None
1333 };
1334
1335 response.send(proto::JoinRoomResponse {
1336 room: Some(joined_room.room),
1337 channel_id: None,
1338 live_kit_connection_info,
1339 })?;
1340
1341 update_user_contacts(session.user_id(), &session).await?;
1342 Ok(())
1343}
1344
1345/// Rejoin room is used to reconnect to a room after connection errors.
1346async fn rejoin_room(
1347 request: proto::RejoinRoom,
1348 response: Response<proto::RejoinRoom>,
1349 session: Session,
1350) -> Result<()> {
1351 let room;
1352 let channel;
1353 {
1354 let mut rejoined_room = session
1355 .db()
1356 .await
1357 .rejoin_room(request, session.user_id(), session.connection_id)
1358 .await?;
1359
1360 response.send(proto::RejoinRoomResponse {
1361 room: Some(rejoined_room.room.clone()),
1362 reshared_projects: rejoined_room
1363 .reshared_projects
1364 .iter()
1365 .map(|project| proto::ResharedProject {
1366 id: project.id.to_proto(),
1367 collaborators: project
1368 .collaborators
1369 .iter()
1370 .map(|collaborator| collaborator.to_proto())
1371 .collect(),
1372 })
1373 .collect(),
1374 rejoined_projects: rejoined_room
1375 .rejoined_projects
1376 .iter()
1377 .map(|rejoined_project| rejoined_project.to_proto())
1378 .collect(),
1379 })?;
1380 room_updated(&rejoined_room.room, &session.peer);
1381
1382 for project in &rejoined_room.reshared_projects {
1383 for collaborator in &project.collaborators {
1384 session
1385 .peer
1386 .send(
1387 collaborator.connection_id,
1388 proto::UpdateProjectCollaborator {
1389 project_id: project.id.to_proto(),
1390 old_peer_id: Some(project.old_connection_id.into()),
1391 new_peer_id: Some(session.connection_id.into()),
1392 },
1393 )
1394 .trace_err();
1395 }
1396
1397 broadcast(
1398 Some(session.connection_id),
1399 project
1400 .collaborators
1401 .iter()
1402 .map(|collaborator| collaborator.connection_id),
1403 |connection_id| {
1404 session.peer.forward_send(
1405 session.connection_id,
1406 connection_id,
1407 proto::UpdateProject {
1408 project_id: project.id.to_proto(),
1409 worktrees: project.worktrees.clone(),
1410 },
1411 )
1412 },
1413 );
1414 }
1415
1416 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1417
1418 let rejoined_room = rejoined_room.into_inner();
1419
1420 room = rejoined_room.room;
1421 channel = rejoined_room.channel;
1422 }
1423
1424 if let Some(channel) = channel {
1425 channel_updated(
1426 &channel,
1427 &room,
1428 &session.peer,
1429 &*session.connection_pool().await,
1430 );
1431 }
1432
1433 update_user_contacts(session.user_id(), &session).await?;
1434 Ok(())
1435}
1436
1437fn notify_rejoined_projects(
1438 rejoined_projects: &mut Vec<RejoinedProject>,
1439 session: &Session,
1440) -> Result<()> {
1441 for project in rejoined_projects.iter() {
1442 for collaborator in &project.collaborators {
1443 session
1444 .peer
1445 .send(
1446 collaborator.connection_id,
1447 proto::UpdateProjectCollaborator {
1448 project_id: project.id.to_proto(),
1449 old_peer_id: Some(project.old_connection_id.into()),
1450 new_peer_id: Some(session.connection_id.into()),
1451 },
1452 )
1453 .trace_err();
1454 }
1455 }
1456
1457 for project in rejoined_projects {
1458 for worktree in mem::take(&mut project.worktrees) {
1459 // Stream this worktree's entries.
1460 let message = proto::UpdateWorktree {
1461 project_id: project.id.to_proto(),
1462 worktree_id: worktree.id,
1463 abs_path: worktree.abs_path.clone(),
1464 root_name: worktree.root_name,
1465 updated_entries: worktree.updated_entries,
1466 removed_entries: worktree.removed_entries,
1467 scan_id: worktree.scan_id,
1468 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1469 updated_repositories: worktree.updated_repositories,
1470 removed_repositories: worktree.removed_repositories,
1471 };
1472 for update in proto::split_worktree_update(message) {
1473 session.peer.send(session.connection_id, update)?;
1474 }
1475
1476 // Stream this worktree's diagnostics.
1477 for summary in worktree.diagnostic_summaries {
1478 session.peer.send(
1479 session.connection_id,
1480 proto::UpdateDiagnosticSummary {
1481 project_id: project.id.to_proto(),
1482 worktree_id: worktree.id,
1483 summary: Some(summary),
1484 },
1485 )?;
1486 }
1487
1488 for settings_file in worktree.settings_files {
1489 session.peer.send(
1490 session.connection_id,
1491 proto::UpdateWorktreeSettings {
1492 project_id: project.id.to_proto(),
1493 worktree_id: worktree.id,
1494 path: settings_file.path,
1495 content: Some(settings_file.content),
1496 kind: Some(settings_file.kind.to_proto().into()),
1497 },
1498 )?;
1499 }
1500 }
1501
1502 for repository in mem::take(&mut project.updated_repositories) {
1503 for update in split_repository_update(repository) {
1504 session.peer.send(session.connection_id, update)?;
1505 }
1506 }
1507
1508 for id in mem::take(&mut project.removed_repositories) {
1509 session.peer.send(
1510 session.connection_id,
1511 proto::RemoveRepository {
1512 project_id: project.id.to_proto(),
1513 id,
1514 },
1515 )?;
1516 }
1517 }
1518
1519 Ok(())
1520}
1521
1522/// leave room disconnects from the room.
1523async fn leave_room(
1524 _: proto::LeaveRoom,
1525 response: Response<proto::LeaveRoom>,
1526 session: Session,
1527) -> Result<()> {
1528 leave_room_for_session(&session, session.connection_id).await?;
1529 response.send(proto::Ack {})?;
1530 Ok(())
1531}
1532
1533/// Updates the permissions of someone else in the room.
1534async fn set_room_participant_role(
1535 request: proto::SetRoomParticipantRole,
1536 response: Response<proto::SetRoomParticipantRole>,
1537 session: Session,
1538) -> Result<()> {
1539 let user_id = UserId::from_proto(request.user_id);
1540 let role = ChannelRole::from(request.role());
1541
1542 let (livekit_room, can_publish) = {
1543 let room = session
1544 .db()
1545 .await
1546 .set_room_participant_role(
1547 session.user_id(),
1548 RoomId::from_proto(request.room_id),
1549 user_id,
1550 role,
1551 )
1552 .await?;
1553
1554 let livekit_room = room.livekit_room.clone();
1555 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1556 room_updated(&room, &session.peer);
1557 (livekit_room, can_publish)
1558 };
1559
1560 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1561 live_kit
1562 .update_participant(
1563 livekit_room.clone(),
1564 request.user_id.to_string(),
1565 livekit_api::proto::ParticipantPermission {
1566 can_subscribe: true,
1567 can_publish,
1568 can_publish_data: can_publish,
1569 hidden: false,
1570 recorder: false,
1571 },
1572 )
1573 .await
1574 .trace_err();
1575 }
1576
1577 response.send(proto::Ack {})?;
1578 Ok(())
1579}
1580
1581/// Call someone else into the current room
1582async fn call(
1583 request: proto::Call,
1584 response: Response<proto::Call>,
1585 session: Session,
1586) -> Result<()> {
1587 let room_id = RoomId::from_proto(request.room_id);
1588 let calling_user_id = session.user_id();
1589 let calling_connection_id = session.connection_id;
1590 let called_user_id = UserId::from_proto(request.called_user_id);
1591 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1592 if !session
1593 .db()
1594 .await
1595 .has_contact(calling_user_id, called_user_id)
1596 .await?
1597 {
1598 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1599 }
1600
1601 let incoming_call = {
1602 let (room, incoming_call) = &mut *session
1603 .db()
1604 .await
1605 .call(
1606 room_id,
1607 calling_user_id,
1608 calling_connection_id,
1609 called_user_id,
1610 initial_project_id,
1611 )
1612 .await?;
1613 room_updated(room, &session.peer);
1614 mem::take(incoming_call)
1615 };
1616 update_user_contacts(called_user_id, &session).await?;
1617
1618 let mut calls = session
1619 .connection_pool()
1620 .await
1621 .user_connection_ids(called_user_id)
1622 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1623 .collect::<FuturesUnordered<_>>();
1624
1625 while let Some(call_response) = calls.next().await {
1626 match call_response.as_ref() {
1627 Ok(_) => {
1628 response.send(proto::Ack {})?;
1629 return Ok(());
1630 }
1631 Err(_) => {
1632 call_response.trace_err();
1633 }
1634 }
1635 }
1636
1637 {
1638 let room = session
1639 .db()
1640 .await
1641 .call_failed(room_id, called_user_id)
1642 .await?;
1643 room_updated(&room, &session.peer);
1644 }
1645 update_user_contacts(called_user_id, &session).await?;
1646
1647 Err(anyhow!("failed to ring user"))?
1648}
1649
1650/// Cancel an outgoing call.
1651async fn cancel_call(
1652 request: proto::CancelCall,
1653 response: Response<proto::CancelCall>,
1654 session: Session,
1655) -> Result<()> {
1656 let called_user_id = UserId::from_proto(request.called_user_id);
1657 let room_id = RoomId::from_proto(request.room_id);
1658 {
1659 let room = session
1660 .db()
1661 .await
1662 .cancel_call(room_id, session.connection_id, called_user_id)
1663 .await?;
1664 room_updated(&room, &session.peer);
1665 }
1666
1667 for connection_id in session
1668 .connection_pool()
1669 .await
1670 .user_connection_ids(called_user_id)
1671 {
1672 session
1673 .peer
1674 .send(
1675 connection_id,
1676 proto::CallCanceled {
1677 room_id: room_id.to_proto(),
1678 },
1679 )
1680 .trace_err();
1681 }
1682 response.send(proto::Ack {})?;
1683
1684 update_user_contacts(called_user_id, &session).await?;
1685 Ok(())
1686}
1687
1688/// Decline an incoming call.
1689async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1690 let room_id = RoomId::from_proto(message.room_id);
1691 {
1692 let room = session
1693 .db()
1694 .await
1695 .decline_call(Some(room_id), session.user_id())
1696 .await?
1697 .ok_or_else(|| anyhow!("failed to decline call"))?;
1698 room_updated(&room, &session.peer);
1699 }
1700
1701 for connection_id in session
1702 .connection_pool()
1703 .await
1704 .user_connection_ids(session.user_id())
1705 {
1706 session
1707 .peer
1708 .send(
1709 connection_id,
1710 proto::CallCanceled {
1711 room_id: room_id.to_proto(),
1712 },
1713 )
1714 .trace_err();
1715 }
1716 update_user_contacts(session.user_id(), &session).await?;
1717 Ok(())
1718}
1719
1720/// Updates other participants in the room with your current location.
1721async fn update_participant_location(
1722 request: proto::UpdateParticipantLocation,
1723 response: Response<proto::UpdateParticipantLocation>,
1724 session: Session,
1725) -> Result<()> {
1726 let room_id = RoomId::from_proto(request.room_id);
1727 let location = request
1728 .location
1729 .ok_or_else(|| anyhow!("invalid location"))?;
1730
1731 let db = session.db().await;
1732 let room = db
1733 .update_room_participant_location(room_id, session.connection_id, location)
1734 .await?;
1735
1736 room_updated(&room, &session.peer);
1737 response.send(proto::Ack {})?;
1738 Ok(())
1739}
1740
1741/// Share a project into the room.
1742async fn share_project(
1743 request: proto::ShareProject,
1744 response: Response<proto::ShareProject>,
1745 session: Session,
1746) -> Result<()> {
1747 let (project_id, room) = &*session
1748 .db()
1749 .await
1750 .share_project(
1751 RoomId::from_proto(request.room_id),
1752 session.connection_id,
1753 &request.worktrees,
1754 request.is_ssh_project,
1755 )
1756 .await?;
1757 response.send(proto::ShareProjectResponse {
1758 project_id: project_id.to_proto(),
1759 })?;
1760 room_updated(room, &session.peer);
1761
1762 Ok(())
1763}
1764
1765/// Unshare a project from the room.
1766async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1767 let project_id = ProjectId::from_proto(message.project_id);
1768 unshare_project_internal(project_id, session.connection_id, &session).await
1769}
1770
1771async fn unshare_project_internal(
1772 project_id: ProjectId,
1773 connection_id: ConnectionId,
1774 session: &Session,
1775) -> Result<()> {
1776 let delete = {
1777 let room_guard = session
1778 .db()
1779 .await
1780 .unshare_project(project_id, connection_id)
1781 .await?;
1782
1783 let (delete, room, guest_connection_ids) = &*room_guard;
1784
1785 let message = proto::UnshareProject {
1786 project_id: project_id.to_proto(),
1787 };
1788
1789 broadcast(
1790 Some(connection_id),
1791 guest_connection_ids.iter().copied(),
1792 |conn_id| session.peer.send(conn_id, message.clone()),
1793 );
1794 if let Some(room) = room {
1795 room_updated(room, &session.peer);
1796 }
1797
1798 *delete
1799 };
1800
1801 if delete {
1802 let db = session.db().await;
1803 db.delete_project(project_id).await?;
1804 }
1805
1806 Ok(())
1807}
1808
1809/// Join someone elses shared project.
1810async fn join_project(
1811 request: proto::JoinProject,
1812 response: Response<proto::JoinProject>,
1813 session: Session,
1814) -> Result<()> {
1815 let project_id = ProjectId::from_proto(request.project_id);
1816
1817 tracing::info!(%project_id, "join project");
1818
1819 let db = session.db().await;
1820 let (project, replica_id) = &mut *db
1821 .join_project(project_id, session.connection_id, session.user_id())
1822 .await?;
1823 drop(db);
1824 tracing::info!(%project_id, "join remote project");
1825 join_project_internal(response, session, project, replica_id)
1826}
1827
1828trait JoinProjectInternalResponse {
1829 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1830}
1831impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1832 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1833 Response::<proto::JoinProject>::send(self, result)
1834 }
1835}
1836
1837fn join_project_internal(
1838 response: impl JoinProjectInternalResponse,
1839 session: Session,
1840 project: &mut Project,
1841 replica_id: &ReplicaId,
1842) -> Result<()> {
1843 let collaborators = project
1844 .collaborators
1845 .iter()
1846 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1847 .map(|collaborator| collaborator.to_proto())
1848 .collect::<Vec<_>>();
1849 let project_id = project.id;
1850 let guest_user_id = session.user_id();
1851
1852 let worktrees = project
1853 .worktrees
1854 .iter()
1855 .map(|(id, worktree)| proto::WorktreeMetadata {
1856 id: *id,
1857 root_name: worktree.root_name.clone(),
1858 visible: worktree.visible,
1859 abs_path: worktree.abs_path.clone(),
1860 })
1861 .collect::<Vec<_>>();
1862
1863 let add_project_collaborator = proto::AddProjectCollaborator {
1864 project_id: project_id.to_proto(),
1865 collaborator: Some(proto::Collaborator {
1866 peer_id: Some(session.connection_id.into()),
1867 replica_id: replica_id.0 as u32,
1868 user_id: guest_user_id.to_proto(),
1869 is_host: false,
1870 }),
1871 };
1872
1873 for collaborator in &collaborators {
1874 session
1875 .peer
1876 .send(
1877 collaborator.peer_id.unwrap().into(),
1878 add_project_collaborator.clone(),
1879 )
1880 .trace_err();
1881 }
1882
1883 // First, we send the metadata associated with each worktree.
1884 response.send(proto::JoinProjectResponse {
1885 project_id: project.id.0 as u64,
1886 worktrees: worktrees.clone(),
1887 replica_id: replica_id.0 as u32,
1888 collaborators: collaborators.clone(),
1889 language_servers: project.language_servers.clone(),
1890 role: project.role.into(),
1891 })?;
1892
1893 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1894 // Stream this worktree's entries.
1895 let message = proto::UpdateWorktree {
1896 project_id: project_id.to_proto(),
1897 worktree_id,
1898 abs_path: worktree.abs_path.clone(),
1899 root_name: worktree.root_name,
1900 updated_entries: worktree.entries,
1901 removed_entries: Default::default(),
1902 scan_id: worktree.scan_id,
1903 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1904 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1905 removed_repositories: Default::default(),
1906 };
1907 for update in proto::split_worktree_update(message) {
1908 session.peer.send(session.connection_id, update.clone())?;
1909 }
1910
1911 // Stream this worktree's diagnostics.
1912 for summary in worktree.diagnostic_summaries {
1913 session.peer.send(
1914 session.connection_id,
1915 proto::UpdateDiagnosticSummary {
1916 project_id: project_id.to_proto(),
1917 worktree_id: worktree.id,
1918 summary: Some(summary),
1919 },
1920 )?;
1921 }
1922
1923 for settings_file in worktree.settings_files {
1924 session.peer.send(
1925 session.connection_id,
1926 proto::UpdateWorktreeSettings {
1927 project_id: project_id.to_proto(),
1928 worktree_id: worktree.id,
1929 path: settings_file.path,
1930 content: Some(settings_file.content),
1931 kind: Some(settings_file.kind.to_proto() as i32),
1932 },
1933 )?;
1934 }
1935 }
1936
1937 for repository in mem::take(&mut project.repositories) {
1938 for update in split_repository_update(repository) {
1939 session.peer.send(session.connection_id, update)?;
1940 }
1941 }
1942
1943 for language_server in &project.language_servers {
1944 session.peer.send(
1945 session.connection_id,
1946 proto::UpdateLanguageServer {
1947 project_id: project_id.to_proto(),
1948 language_server_id: language_server.id,
1949 variant: Some(
1950 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1951 proto::LspDiskBasedDiagnosticsUpdated {},
1952 ),
1953 ),
1954 },
1955 )?;
1956 }
1957
1958 Ok(())
1959}
1960
1961/// Leave someone elses shared project.
1962async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1963 let sender_id = session.connection_id;
1964 let project_id = ProjectId::from_proto(request.project_id);
1965 let db = session.db().await;
1966
1967 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1968 tracing::info!(
1969 %project_id,
1970 "leave project"
1971 );
1972
1973 project_left(project, &session);
1974 if let Some(room) = room {
1975 room_updated(room, &session.peer);
1976 }
1977
1978 Ok(())
1979}
1980
1981/// Updates other participants with changes to the project
1982async fn update_project(
1983 request: proto::UpdateProject,
1984 response: Response<proto::UpdateProject>,
1985 session: Session,
1986) -> Result<()> {
1987 let project_id = ProjectId::from_proto(request.project_id);
1988 let (room, guest_connection_ids) = &*session
1989 .db()
1990 .await
1991 .update_project(project_id, session.connection_id, &request.worktrees)
1992 .await?;
1993 broadcast(
1994 Some(session.connection_id),
1995 guest_connection_ids.iter().copied(),
1996 |connection_id| {
1997 session
1998 .peer
1999 .forward_send(session.connection_id, connection_id, request.clone())
2000 },
2001 );
2002 if let Some(room) = room {
2003 room_updated(room, &session.peer);
2004 }
2005 response.send(proto::Ack {})?;
2006
2007 Ok(())
2008}
2009
2010/// Updates other participants with changes to the worktree
2011async fn update_worktree(
2012 request: proto::UpdateWorktree,
2013 response: Response<proto::UpdateWorktree>,
2014 session: Session,
2015) -> Result<()> {
2016 let guest_connection_ids = session
2017 .db()
2018 .await
2019 .update_worktree(&request, session.connection_id)
2020 .await?;
2021
2022 broadcast(
2023 Some(session.connection_id),
2024 guest_connection_ids.iter().copied(),
2025 |connection_id| {
2026 session
2027 .peer
2028 .forward_send(session.connection_id, connection_id, request.clone())
2029 },
2030 );
2031 response.send(proto::Ack {})?;
2032 Ok(())
2033}
2034
2035async fn update_repository(
2036 request: proto::UpdateRepository,
2037 response: Response<proto::UpdateRepository>,
2038 session: Session,
2039) -> Result<()> {
2040 let guest_connection_ids = session
2041 .db()
2042 .await
2043 .update_repository(&request, session.connection_id)
2044 .await?;
2045
2046 broadcast(
2047 Some(session.connection_id),
2048 guest_connection_ids.iter().copied(),
2049 |connection_id| {
2050 session
2051 .peer
2052 .forward_send(session.connection_id, connection_id, request.clone())
2053 },
2054 );
2055 response.send(proto::Ack {})?;
2056 Ok(())
2057}
2058
2059async fn remove_repository(
2060 request: proto::RemoveRepository,
2061 response: Response<proto::RemoveRepository>,
2062 session: Session,
2063) -> Result<()> {
2064 let guest_connection_ids = session
2065 .db()
2066 .await
2067 .remove_repository(&request, session.connection_id)
2068 .await?;
2069
2070 broadcast(
2071 Some(session.connection_id),
2072 guest_connection_ids.iter().copied(),
2073 |connection_id| {
2074 session
2075 .peer
2076 .forward_send(session.connection_id, connection_id, request.clone())
2077 },
2078 );
2079 response.send(proto::Ack {})?;
2080 Ok(())
2081}
2082
2083/// Updates other participants with changes to the diagnostics
2084async fn update_diagnostic_summary(
2085 message: proto::UpdateDiagnosticSummary,
2086 session: Session,
2087) -> Result<()> {
2088 let guest_connection_ids = session
2089 .db()
2090 .await
2091 .update_diagnostic_summary(&message, session.connection_id)
2092 .await?;
2093
2094 broadcast(
2095 Some(session.connection_id),
2096 guest_connection_ids.iter().copied(),
2097 |connection_id| {
2098 session
2099 .peer
2100 .forward_send(session.connection_id, connection_id, message.clone())
2101 },
2102 );
2103
2104 Ok(())
2105}
2106
2107/// Updates other participants with changes to the worktree settings
2108async fn update_worktree_settings(
2109 message: proto::UpdateWorktreeSettings,
2110 session: Session,
2111) -> Result<()> {
2112 let guest_connection_ids = session
2113 .db()
2114 .await
2115 .update_worktree_settings(&message, session.connection_id)
2116 .await?;
2117
2118 broadcast(
2119 Some(session.connection_id),
2120 guest_connection_ids.iter().copied(),
2121 |connection_id| {
2122 session
2123 .peer
2124 .forward_send(session.connection_id, connection_id, message.clone())
2125 },
2126 );
2127
2128 Ok(())
2129}
2130
2131/// Notify other participants that a language server has started.
2132async fn start_language_server(
2133 request: proto::StartLanguageServer,
2134 session: Session,
2135) -> Result<()> {
2136 let guest_connection_ids = session
2137 .db()
2138 .await
2139 .start_language_server(&request, session.connection_id)
2140 .await?;
2141
2142 broadcast(
2143 Some(session.connection_id),
2144 guest_connection_ids.iter().copied(),
2145 |connection_id| {
2146 session
2147 .peer
2148 .forward_send(session.connection_id, connection_id, request.clone())
2149 },
2150 );
2151 Ok(())
2152}
2153
2154/// Notify other participants that a language server has changed.
2155async fn update_language_server(
2156 request: proto::UpdateLanguageServer,
2157 session: Session,
2158) -> Result<()> {
2159 let project_id = ProjectId::from_proto(request.project_id);
2160 let project_connection_ids = session
2161 .db()
2162 .await
2163 .project_connection_ids(project_id, session.connection_id, true)
2164 .await?;
2165 broadcast(
2166 Some(session.connection_id),
2167 project_connection_ids.iter().copied(),
2168 |connection_id| {
2169 session
2170 .peer
2171 .forward_send(session.connection_id, connection_id, request.clone())
2172 },
2173 );
2174 Ok(())
2175}
2176
2177/// forward a project request to the host. These requests should be read only
2178/// as guests are allowed to send them.
2179async fn forward_read_only_project_request<T>(
2180 request: T,
2181 response: Response<T>,
2182 session: Session,
2183) -> Result<()>
2184where
2185 T: EntityMessage + RequestMessage,
2186{
2187 let project_id = ProjectId::from_proto(request.remote_entity_id());
2188 let host_connection_id = session
2189 .db()
2190 .await
2191 .host_for_read_only_project_request(project_id, session.connection_id)
2192 .await?;
2193 let payload = session
2194 .peer
2195 .forward_request(session.connection_id, host_connection_id, request)
2196 .await?;
2197 response.send(payload)?;
2198 Ok(())
2199}
2200
2201async fn forward_find_search_candidates_request(
2202 request: proto::FindSearchCandidates,
2203 response: Response<proto::FindSearchCandidates>,
2204 session: Session,
2205) -> Result<()> {
2206 let project_id = ProjectId::from_proto(request.remote_entity_id());
2207 let host_connection_id = session
2208 .db()
2209 .await
2210 .host_for_read_only_project_request(project_id, session.connection_id)
2211 .await?;
2212 let payload = session
2213 .peer
2214 .forward_request(session.connection_id, host_connection_id, request)
2215 .await?;
2216 response.send(payload)?;
2217 Ok(())
2218}
2219
2220/// forward a project request to the host. These requests are disallowed
2221/// for guests.
2222async fn forward_mutating_project_request<T>(
2223 request: T,
2224 response: Response<T>,
2225 session: Session,
2226) -> Result<()>
2227where
2228 T: EntityMessage + RequestMessage,
2229{
2230 let project_id = ProjectId::from_proto(request.remote_entity_id());
2231
2232 let host_connection_id = session
2233 .db()
2234 .await
2235 .host_for_mutating_project_request(project_id, session.connection_id)
2236 .await?;
2237 let payload = session
2238 .peer
2239 .forward_request(session.connection_id, host_connection_id, request)
2240 .await?;
2241 response.send(payload)?;
2242 Ok(())
2243}
2244
2245/// Notify other participants that a new buffer has been created
2246async fn create_buffer_for_peer(
2247 request: proto::CreateBufferForPeer,
2248 session: Session,
2249) -> Result<()> {
2250 session
2251 .db()
2252 .await
2253 .check_user_is_project_host(
2254 ProjectId::from_proto(request.project_id),
2255 session.connection_id,
2256 )
2257 .await?;
2258 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2259 session
2260 .peer
2261 .forward_send(session.connection_id, peer_id.into(), request)?;
2262 Ok(())
2263}
2264
2265/// Notify other participants that a buffer has been updated. This is
2266/// allowed for guests as long as the update is limited to selections.
2267async fn update_buffer(
2268 request: proto::UpdateBuffer,
2269 response: Response<proto::UpdateBuffer>,
2270 session: Session,
2271) -> Result<()> {
2272 let project_id = ProjectId::from_proto(request.project_id);
2273 let mut capability = Capability::ReadOnly;
2274
2275 for op in request.operations.iter() {
2276 match op.variant {
2277 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2278 Some(_) => capability = Capability::ReadWrite,
2279 }
2280 }
2281
2282 let host = {
2283 let guard = session
2284 .db()
2285 .await
2286 .connections_for_buffer_update(project_id, session.connection_id, capability)
2287 .await?;
2288
2289 let (host, guests) = &*guard;
2290
2291 broadcast(
2292 Some(session.connection_id),
2293 guests.clone(),
2294 |connection_id| {
2295 session
2296 .peer
2297 .forward_send(session.connection_id, connection_id, request.clone())
2298 },
2299 );
2300
2301 *host
2302 };
2303
2304 if host != session.connection_id {
2305 session
2306 .peer
2307 .forward_request(session.connection_id, host, request.clone())
2308 .await?;
2309 }
2310
2311 response.send(proto::Ack {})?;
2312 Ok(())
2313}
2314
2315async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2316 let project_id = ProjectId::from_proto(message.project_id);
2317
2318 let operation = message.operation.as_ref().context("invalid operation")?;
2319 let capability = match operation.variant.as_ref() {
2320 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2321 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2322 match buffer_op.variant {
2323 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2324 Capability::ReadOnly
2325 }
2326 _ => Capability::ReadWrite,
2327 }
2328 } else {
2329 Capability::ReadWrite
2330 }
2331 }
2332 Some(_) => Capability::ReadWrite,
2333 None => Capability::ReadOnly,
2334 };
2335
2336 let guard = session
2337 .db()
2338 .await
2339 .connections_for_buffer_update(project_id, session.connection_id, capability)
2340 .await?;
2341
2342 let (host, guests) = &*guard;
2343
2344 broadcast(
2345 Some(session.connection_id),
2346 guests.iter().chain([host]).copied(),
2347 |connection_id| {
2348 session
2349 .peer
2350 .forward_send(session.connection_id, connection_id, message.clone())
2351 },
2352 );
2353
2354 Ok(())
2355}
2356
2357/// Notify other participants that a project has been updated.
2358async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2359 request: T,
2360 session: Session,
2361) -> Result<()> {
2362 let project_id = ProjectId::from_proto(request.remote_entity_id());
2363 let project_connection_ids = session
2364 .db()
2365 .await
2366 .project_connection_ids(project_id, session.connection_id, false)
2367 .await?;
2368
2369 broadcast(
2370 Some(session.connection_id),
2371 project_connection_ids.iter().copied(),
2372 |connection_id| {
2373 session
2374 .peer
2375 .forward_send(session.connection_id, connection_id, request.clone())
2376 },
2377 );
2378 Ok(())
2379}
2380
2381/// Start following another user in a call.
2382async fn follow(
2383 request: proto::Follow,
2384 response: Response<proto::Follow>,
2385 session: Session,
2386) -> Result<()> {
2387 let room_id = RoomId::from_proto(request.room_id);
2388 let project_id = request.project_id.map(ProjectId::from_proto);
2389 let leader_id = request
2390 .leader_id
2391 .ok_or_else(|| anyhow!("invalid leader id"))?
2392 .into();
2393 let follower_id = session.connection_id;
2394
2395 session
2396 .db()
2397 .await
2398 .check_room_participants(room_id, leader_id, session.connection_id)
2399 .await?;
2400
2401 let response_payload = session
2402 .peer
2403 .forward_request(session.connection_id, leader_id, request)
2404 .await?;
2405 response.send(response_payload)?;
2406
2407 if let Some(project_id) = project_id {
2408 let room = session
2409 .db()
2410 .await
2411 .follow(room_id, project_id, leader_id, follower_id)
2412 .await?;
2413 room_updated(&room, &session.peer);
2414 }
2415
2416 Ok(())
2417}
2418
2419/// Stop following another user in a call.
2420async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2421 let room_id = RoomId::from_proto(request.room_id);
2422 let project_id = request.project_id.map(ProjectId::from_proto);
2423 let leader_id = request
2424 .leader_id
2425 .ok_or_else(|| anyhow!("invalid leader id"))?
2426 .into();
2427 let follower_id = session.connection_id;
2428
2429 session
2430 .db()
2431 .await
2432 .check_room_participants(room_id, leader_id, session.connection_id)
2433 .await?;
2434
2435 session
2436 .peer
2437 .forward_send(session.connection_id, leader_id, request)?;
2438
2439 if let Some(project_id) = project_id {
2440 let room = session
2441 .db()
2442 .await
2443 .unfollow(room_id, project_id, leader_id, follower_id)
2444 .await?;
2445 room_updated(&room, &session.peer);
2446 }
2447
2448 Ok(())
2449}
2450
2451/// Notify everyone following you of your current location.
2452async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2453 let room_id = RoomId::from_proto(request.room_id);
2454 let database = session.db.lock().await;
2455
2456 let connection_ids = if let Some(project_id) = request.project_id {
2457 let project_id = ProjectId::from_proto(project_id);
2458 database
2459 .project_connection_ids(project_id, session.connection_id, true)
2460 .await?
2461 } else {
2462 database
2463 .room_connection_ids(room_id, session.connection_id)
2464 .await?
2465 };
2466
2467 // For now, don't send view update messages back to that view's current leader.
2468 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2469 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2470 _ => None,
2471 });
2472
2473 for connection_id in connection_ids.iter().cloned() {
2474 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2475 session
2476 .peer
2477 .forward_send(session.connection_id, connection_id, request.clone())?;
2478 }
2479 }
2480 Ok(())
2481}
2482
2483/// Get public data about users.
2484async fn get_users(
2485 request: proto::GetUsers,
2486 response: Response<proto::GetUsers>,
2487 session: Session,
2488) -> Result<()> {
2489 let user_ids = request
2490 .user_ids
2491 .into_iter()
2492 .map(UserId::from_proto)
2493 .collect();
2494 let users = session
2495 .db()
2496 .await
2497 .get_users_by_ids(user_ids)
2498 .await?
2499 .into_iter()
2500 .map(|user| proto::User {
2501 id: user.id.to_proto(),
2502 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2503 github_login: user.github_login,
2504 email: user.email_address,
2505 name: user.name,
2506 })
2507 .collect();
2508 response.send(proto::UsersResponse { users })?;
2509 Ok(())
2510}
2511
2512/// Search for users (to invite) buy Github login
2513async fn fuzzy_search_users(
2514 request: proto::FuzzySearchUsers,
2515 response: Response<proto::FuzzySearchUsers>,
2516 session: Session,
2517) -> Result<()> {
2518 let query = request.query;
2519 let users = match query.len() {
2520 0 => vec![],
2521 1 | 2 => session
2522 .db()
2523 .await
2524 .get_user_by_github_login(&query)
2525 .await?
2526 .into_iter()
2527 .collect(),
2528 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2529 };
2530 let users = users
2531 .into_iter()
2532 .filter(|user| user.id != session.user_id())
2533 .map(|user| proto::User {
2534 id: user.id.to_proto(),
2535 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2536 github_login: user.github_login,
2537 name: user.name,
2538 email: user.email_address,
2539 })
2540 .collect();
2541 response.send(proto::UsersResponse { users })?;
2542 Ok(())
2543}
2544
2545/// Send a contact request to another user.
2546async fn request_contact(
2547 request: proto::RequestContact,
2548 response: Response<proto::RequestContact>,
2549 session: Session,
2550) -> Result<()> {
2551 let requester_id = session.user_id();
2552 let responder_id = UserId::from_proto(request.responder_id);
2553 if requester_id == responder_id {
2554 return Err(anyhow!("cannot add yourself as a contact"))?;
2555 }
2556
2557 let notifications = session
2558 .db()
2559 .await
2560 .send_contact_request(requester_id, responder_id)
2561 .await?;
2562
2563 // Update outgoing contact requests of requester
2564 let mut update = proto::UpdateContacts::default();
2565 update.outgoing_requests.push(responder_id.to_proto());
2566 for connection_id in session
2567 .connection_pool()
2568 .await
2569 .user_connection_ids(requester_id)
2570 {
2571 session.peer.send(connection_id, update.clone())?;
2572 }
2573
2574 // Update incoming contact requests of responder
2575 let mut update = proto::UpdateContacts::default();
2576 update
2577 .incoming_requests
2578 .push(proto::IncomingContactRequest {
2579 requester_id: requester_id.to_proto(),
2580 });
2581 let connection_pool = session.connection_pool().await;
2582 for connection_id in connection_pool.user_connection_ids(responder_id) {
2583 session.peer.send(connection_id, update.clone())?;
2584 }
2585
2586 send_notifications(&connection_pool, &session.peer, notifications);
2587
2588 response.send(proto::Ack {})?;
2589 Ok(())
2590}
2591
2592/// Accept or decline a contact request
2593async fn respond_to_contact_request(
2594 request: proto::RespondToContactRequest,
2595 response: Response<proto::RespondToContactRequest>,
2596 session: Session,
2597) -> Result<()> {
2598 let responder_id = session.user_id();
2599 let requester_id = UserId::from_proto(request.requester_id);
2600 let db = session.db().await;
2601 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2602 db.dismiss_contact_notification(responder_id, requester_id)
2603 .await?;
2604 } else {
2605 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2606
2607 let notifications = db
2608 .respond_to_contact_request(responder_id, requester_id, accept)
2609 .await?;
2610 let requester_busy = db.is_user_busy(requester_id).await?;
2611 let responder_busy = db.is_user_busy(responder_id).await?;
2612
2613 let pool = session.connection_pool().await;
2614 // Update responder with new contact
2615 let mut update = proto::UpdateContacts::default();
2616 if accept {
2617 update
2618 .contacts
2619 .push(contact_for_user(requester_id, requester_busy, &pool));
2620 }
2621 update
2622 .remove_incoming_requests
2623 .push(requester_id.to_proto());
2624 for connection_id in pool.user_connection_ids(responder_id) {
2625 session.peer.send(connection_id, update.clone())?;
2626 }
2627
2628 // Update requester with new contact
2629 let mut update = proto::UpdateContacts::default();
2630 if accept {
2631 update
2632 .contacts
2633 .push(contact_for_user(responder_id, responder_busy, &pool));
2634 }
2635 update
2636 .remove_outgoing_requests
2637 .push(responder_id.to_proto());
2638
2639 for connection_id in pool.user_connection_ids(requester_id) {
2640 session.peer.send(connection_id, update.clone())?;
2641 }
2642
2643 send_notifications(&pool, &session.peer, notifications);
2644 }
2645
2646 response.send(proto::Ack {})?;
2647 Ok(())
2648}
2649
2650/// Remove a contact.
2651async fn remove_contact(
2652 request: proto::RemoveContact,
2653 response: Response<proto::RemoveContact>,
2654 session: Session,
2655) -> Result<()> {
2656 let requester_id = session.user_id();
2657 let responder_id = UserId::from_proto(request.user_id);
2658 let db = session.db().await;
2659 let (contact_accepted, deleted_notification_id) =
2660 db.remove_contact(requester_id, responder_id).await?;
2661
2662 let pool = session.connection_pool().await;
2663 // Update outgoing contact requests of requester
2664 let mut update = proto::UpdateContacts::default();
2665 if contact_accepted {
2666 update.remove_contacts.push(responder_id.to_proto());
2667 } else {
2668 update
2669 .remove_outgoing_requests
2670 .push(responder_id.to_proto());
2671 }
2672 for connection_id in pool.user_connection_ids(requester_id) {
2673 session.peer.send(connection_id, update.clone())?;
2674 }
2675
2676 // Update incoming contact requests of responder
2677 let mut update = proto::UpdateContacts::default();
2678 if contact_accepted {
2679 update.remove_contacts.push(requester_id.to_proto());
2680 } else {
2681 update
2682 .remove_incoming_requests
2683 .push(requester_id.to_proto());
2684 }
2685 for connection_id in pool.user_connection_ids(responder_id) {
2686 session.peer.send(connection_id, update.clone())?;
2687 if let Some(notification_id) = deleted_notification_id {
2688 session.peer.send(
2689 connection_id,
2690 proto::DeleteNotification {
2691 notification_id: notification_id.to_proto(),
2692 },
2693 )?;
2694 }
2695 }
2696
2697 response.send(proto::Ack {})?;
2698 Ok(())
2699}
2700
2701fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2702 version.0.minor() < 139
2703}
2704
2705async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
2706 let db = session.db().await;
2707
2708 let feature_flags = db.get_user_flags(user_id).await?;
2709 let plan = session.current_plan(&db).await?;
2710 let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
2711 let billing_preferences = db.get_billing_preferences(user_id).await?;
2712
2713 let usage = if let Some(llm_db) = session.app_state.llm_db.clone() {
2714 let subscription = db.get_active_billing_subscription(user_id).await?;
2715
2716 let subscription_period = maybe!({
2717 let subscription = subscription?;
2718 let period_start_at = subscription.current_period_start_at()?;
2719 let period_end_at = subscription.current_period_end_at()?;
2720
2721 Some((period_start_at, period_end_at))
2722 });
2723
2724 if let Some((period_start_at, period_end_at)) = subscription_period {
2725 llm_db
2726 .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
2727 .await?
2728 } else {
2729 None
2730 }
2731 } else {
2732 None
2733 };
2734
2735 session
2736 .peer
2737 .send(
2738 session.connection_id,
2739 proto::UpdateUserPlan {
2740 plan: plan.into(),
2741 trial_started_at: billing_customer
2742 .and_then(|billing_customer| billing_customer.trial_started_at)
2743 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2744 is_usage_based_billing_enabled: billing_preferences
2745 .map(|preferences| preferences.model_request_overages_enabled),
2746 usage: usage.map(|usage| {
2747 let plan = match plan {
2748 proto::Plan::Free => zed_llm_client::Plan::Free,
2749 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2750 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2751 };
2752
2753 let model_requests_limit = match plan.model_requests_limit() {
2754 zed_llm_client::UsageLimit::Limited(limit) => {
2755 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2756 && feature_flags
2757 .iter()
2758 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2759 {
2760 1_000
2761 } else {
2762 limit
2763 };
2764
2765 zed_llm_client::UsageLimit::Limited(limit)
2766 }
2767 zed_llm_client::UsageLimit::Unlimited => {
2768 zed_llm_client::UsageLimit::Unlimited
2769 }
2770 };
2771
2772 proto::SubscriptionUsage {
2773 model_requests_usage_amount: usage.model_requests as u32,
2774 model_requests_usage_limit: Some(proto::UsageLimit {
2775 variant: Some(match model_requests_limit {
2776 zed_llm_client::UsageLimit::Limited(limit) => {
2777 proto::usage_limit::Variant::Limited(
2778 proto::usage_limit::Limited {
2779 limit: limit as u32,
2780 },
2781 )
2782 }
2783 zed_llm_client::UsageLimit::Unlimited => {
2784 proto::usage_limit::Variant::Unlimited(
2785 proto::usage_limit::Unlimited {},
2786 )
2787 }
2788 }),
2789 }),
2790 edit_predictions_usage_amount: usage.edit_predictions as u32,
2791 edit_predictions_usage_limit: Some(proto::UsageLimit {
2792 variant: Some(match plan.edit_predictions_limit() {
2793 zed_llm_client::UsageLimit::Limited(limit) => {
2794 proto::usage_limit::Variant::Limited(
2795 proto::usage_limit::Limited {
2796 limit: limit as u32,
2797 },
2798 )
2799 }
2800 zed_llm_client::UsageLimit::Unlimited => {
2801 proto::usage_limit::Variant::Unlimited(
2802 proto::usage_limit::Unlimited {},
2803 )
2804 }
2805 }),
2806 }),
2807 }
2808 }),
2809 },
2810 )
2811 .trace_err();
2812
2813 Ok(())
2814}
2815
2816async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2817 subscribe_user_to_channels(session.user_id(), &session).await?;
2818 Ok(())
2819}
2820
2821async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2822 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2823 let mut pool = session.connection_pool().await;
2824 for membership in &channels_for_user.channel_memberships {
2825 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2826 }
2827 session.peer.send(
2828 session.connection_id,
2829 build_update_user_channels(&channels_for_user),
2830 )?;
2831 session.peer.send(
2832 session.connection_id,
2833 build_channels_update(channels_for_user),
2834 )?;
2835 Ok(())
2836}
2837
2838/// Creates a new channel.
2839async fn create_channel(
2840 request: proto::CreateChannel,
2841 response: Response<proto::CreateChannel>,
2842 session: Session,
2843) -> Result<()> {
2844 let db = session.db().await;
2845
2846 let parent_id = request.parent_id.map(ChannelId::from_proto);
2847 let (channel, membership) = db
2848 .create_channel(&request.name, parent_id, session.user_id())
2849 .await?;
2850
2851 let root_id = channel.root_id();
2852 let channel = Channel::from_model(channel);
2853
2854 response.send(proto::CreateChannelResponse {
2855 channel: Some(channel.to_proto()),
2856 parent_id: request.parent_id,
2857 })?;
2858
2859 let mut connection_pool = session.connection_pool().await;
2860 if let Some(membership) = membership {
2861 connection_pool.subscribe_to_channel(
2862 membership.user_id,
2863 membership.channel_id,
2864 membership.role,
2865 );
2866 let update = proto::UpdateUserChannels {
2867 channel_memberships: vec![proto::ChannelMembership {
2868 channel_id: membership.channel_id.to_proto(),
2869 role: membership.role.into(),
2870 }],
2871 ..Default::default()
2872 };
2873 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2874 session.peer.send(connection_id, update.clone())?;
2875 }
2876 }
2877
2878 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2879 if !role.can_see_channel(channel.visibility) {
2880 continue;
2881 }
2882
2883 let update = proto::UpdateChannels {
2884 channels: vec![channel.to_proto()],
2885 ..Default::default()
2886 };
2887 session.peer.send(connection_id, update.clone())?;
2888 }
2889
2890 Ok(())
2891}
2892
2893/// Delete a channel
2894async fn delete_channel(
2895 request: proto::DeleteChannel,
2896 response: Response<proto::DeleteChannel>,
2897 session: Session,
2898) -> Result<()> {
2899 let db = session.db().await;
2900
2901 let channel_id = request.channel_id;
2902 let (root_channel, removed_channels) = db
2903 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2904 .await?;
2905 response.send(proto::Ack {})?;
2906
2907 // Notify members of removed channels
2908 let mut update = proto::UpdateChannels::default();
2909 update
2910 .delete_channels
2911 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2912
2913 let connection_pool = session.connection_pool().await;
2914 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2915 session.peer.send(connection_id, update.clone())?;
2916 }
2917
2918 Ok(())
2919}
2920
2921/// Invite someone to join a channel.
2922async fn invite_channel_member(
2923 request: proto::InviteChannelMember,
2924 response: Response<proto::InviteChannelMember>,
2925 session: Session,
2926) -> Result<()> {
2927 let db = session.db().await;
2928 let channel_id = ChannelId::from_proto(request.channel_id);
2929 let invitee_id = UserId::from_proto(request.user_id);
2930 let InviteMemberResult {
2931 channel,
2932 notifications,
2933 } = db
2934 .invite_channel_member(
2935 channel_id,
2936 invitee_id,
2937 session.user_id(),
2938 request.role().into(),
2939 )
2940 .await?;
2941
2942 let update = proto::UpdateChannels {
2943 channel_invitations: vec![channel.to_proto()],
2944 ..Default::default()
2945 };
2946
2947 let connection_pool = session.connection_pool().await;
2948 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2949 session.peer.send(connection_id, update.clone())?;
2950 }
2951
2952 send_notifications(&connection_pool, &session.peer, notifications);
2953
2954 response.send(proto::Ack {})?;
2955 Ok(())
2956}
2957
2958/// remove someone from a channel
2959async fn remove_channel_member(
2960 request: proto::RemoveChannelMember,
2961 response: Response<proto::RemoveChannelMember>,
2962 session: Session,
2963) -> Result<()> {
2964 let db = session.db().await;
2965 let channel_id = ChannelId::from_proto(request.channel_id);
2966 let member_id = UserId::from_proto(request.user_id);
2967
2968 let RemoveChannelMemberResult {
2969 membership_update,
2970 notification_id,
2971 } = db
2972 .remove_channel_member(channel_id, member_id, session.user_id())
2973 .await?;
2974
2975 let mut connection_pool = session.connection_pool().await;
2976 notify_membership_updated(
2977 &mut connection_pool,
2978 membership_update,
2979 member_id,
2980 &session.peer,
2981 );
2982 for connection_id in connection_pool.user_connection_ids(member_id) {
2983 if let Some(notification_id) = notification_id {
2984 session
2985 .peer
2986 .send(
2987 connection_id,
2988 proto::DeleteNotification {
2989 notification_id: notification_id.to_proto(),
2990 },
2991 )
2992 .trace_err();
2993 }
2994 }
2995
2996 response.send(proto::Ack {})?;
2997 Ok(())
2998}
2999
3000/// Toggle the channel between public and private.
3001/// Care is taken to maintain the invariant that public channels only descend from public channels,
3002/// (though members-only channels can appear at any point in the hierarchy).
3003async fn set_channel_visibility(
3004 request: proto::SetChannelVisibility,
3005 response: Response<proto::SetChannelVisibility>,
3006 session: Session,
3007) -> Result<()> {
3008 let db = session.db().await;
3009 let channel_id = ChannelId::from_proto(request.channel_id);
3010 let visibility = request.visibility().into();
3011
3012 let channel_model = db
3013 .set_channel_visibility(channel_id, visibility, session.user_id())
3014 .await?;
3015 let root_id = channel_model.root_id();
3016 let channel = Channel::from_model(channel_model);
3017
3018 let mut connection_pool = session.connection_pool().await;
3019 for (user_id, role) in connection_pool
3020 .channel_user_ids(root_id)
3021 .collect::<Vec<_>>()
3022 .into_iter()
3023 {
3024 let update = if role.can_see_channel(channel.visibility) {
3025 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3026 proto::UpdateChannels {
3027 channels: vec![channel.to_proto()],
3028 ..Default::default()
3029 }
3030 } else {
3031 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3032 proto::UpdateChannels {
3033 delete_channels: vec![channel.id.to_proto()],
3034 ..Default::default()
3035 }
3036 };
3037
3038 for connection_id in connection_pool.user_connection_ids(user_id) {
3039 session.peer.send(connection_id, update.clone())?;
3040 }
3041 }
3042
3043 response.send(proto::Ack {})?;
3044 Ok(())
3045}
3046
3047/// Alter the role for a user in the channel.
3048async fn set_channel_member_role(
3049 request: proto::SetChannelMemberRole,
3050 response: Response<proto::SetChannelMemberRole>,
3051 session: Session,
3052) -> Result<()> {
3053 let db = session.db().await;
3054 let channel_id = ChannelId::from_proto(request.channel_id);
3055 let member_id = UserId::from_proto(request.user_id);
3056 let result = db
3057 .set_channel_member_role(
3058 channel_id,
3059 session.user_id(),
3060 member_id,
3061 request.role().into(),
3062 )
3063 .await?;
3064
3065 match result {
3066 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3067 let mut connection_pool = session.connection_pool().await;
3068 notify_membership_updated(
3069 &mut connection_pool,
3070 membership_update,
3071 member_id,
3072 &session.peer,
3073 )
3074 }
3075 db::SetMemberRoleResult::InviteUpdated(channel) => {
3076 let update = proto::UpdateChannels {
3077 channel_invitations: vec![channel.to_proto()],
3078 ..Default::default()
3079 };
3080
3081 for connection_id in session
3082 .connection_pool()
3083 .await
3084 .user_connection_ids(member_id)
3085 {
3086 session.peer.send(connection_id, update.clone())?;
3087 }
3088 }
3089 }
3090
3091 response.send(proto::Ack {})?;
3092 Ok(())
3093}
3094
3095/// Change the name of a channel
3096async fn rename_channel(
3097 request: proto::RenameChannel,
3098 response: Response<proto::RenameChannel>,
3099 session: Session,
3100) -> Result<()> {
3101 let db = session.db().await;
3102 let channel_id = ChannelId::from_proto(request.channel_id);
3103 let channel_model = db
3104 .rename_channel(channel_id, session.user_id(), &request.name)
3105 .await?;
3106 let root_id = channel_model.root_id();
3107 let channel = Channel::from_model(channel_model);
3108
3109 response.send(proto::RenameChannelResponse {
3110 channel: Some(channel.to_proto()),
3111 })?;
3112
3113 let connection_pool = session.connection_pool().await;
3114 let update = proto::UpdateChannels {
3115 channels: vec![channel.to_proto()],
3116 ..Default::default()
3117 };
3118 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3119 if role.can_see_channel(channel.visibility) {
3120 session.peer.send(connection_id, update.clone())?;
3121 }
3122 }
3123
3124 Ok(())
3125}
3126
3127/// Move a channel to a new parent.
3128async fn move_channel(
3129 request: proto::MoveChannel,
3130 response: Response<proto::MoveChannel>,
3131 session: Session,
3132) -> Result<()> {
3133 let channel_id = ChannelId::from_proto(request.channel_id);
3134 let to = ChannelId::from_proto(request.to);
3135
3136 let (root_id, channels) = session
3137 .db()
3138 .await
3139 .move_channel(channel_id, to, session.user_id())
3140 .await?;
3141
3142 let connection_pool = session.connection_pool().await;
3143 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3144 let channels = channels
3145 .iter()
3146 .filter_map(|channel| {
3147 if role.can_see_channel(channel.visibility) {
3148 Some(channel.to_proto())
3149 } else {
3150 None
3151 }
3152 })
3153 .collect::<Vec<_>>();
3154 if channels.is_empty() {
3155 continue;
3156 }
3157
3158 let update = proto::UpdateChannels {
3159 channels,
3160 ..Default::default()
3161 };
3162
3163 session.peer.send(connection_id, update.clone())?;
3164 }
3165
3166 response.send(Ack {})?;
3167 Ok(())
3168}
3169
3170/// Get the list of channel members
3171async fn get_channel_members(
3172 request: proto::GetChannelMembers,
3173 response: Response<proto::GetChannelMembers>,
3174 session: Session,
3175) -> Result<()> {
3176 let db = session.db().await;
3177 let channel_id = ChannelId::from_proto(request.channel_id);
3178 let limit = if request.limit == 0 {
3179 u16::MAX as u64
3180 } else {
3181 request.limit
3182 };
3183 let (members, users) = db
3184 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3185 .await?;
3186 response.send(proto::GetChannelMembersResponse { members, users })?;
3187 Ok(())
3188}
3189
3190/// Accept or decline a channel invitation.
3191async fn respond_to_channel_invite(
3192 request: proto::RespondToChannelInvite,
3193 response: Response<proto::RespondToChannelInvite>,
3194 session: Session,
3195) -> Result<()> {
3196 let db = session.db().await;
3197 let channel_id = ChannelId::from_proto(request.channel_id);
3198 let RespondToChannelInvite {
3199 membership_update,
3200 notifications,
3201 } = db
3202 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3203 .await?;
3204
3205 let mut connection_pool = session.connection_pool().await;
3206 if let Some(membership_update) = membership_update {
3207 notify_membership_updated(
3208 &mut connection_pool,
3209 membership_update,
3210 session.user_id(),
3211 &session.peer,
3212 );
3213 } else {
3214 let update = proto::UpdateChannels {
3215 remove_channel_invitations: vec![channel_id.to_proto()],
3216 ..Default::default()
3217 };
3218
3219 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3220 session.peer.send(connection_id, update.clone())?;
3221 }
3222 };
3223
3224 send_notifications(&connection_pool, &session.peer, notifications);
3225
3226 response.send(proto::Ack {})?;
3227
3228 Ok(())
3229}
3230
3231/// Join the channels' room
3232async fn join_channel(
3233 request: proto::JoinChannel,
3234 response: Response<proto::JoinChannel>,
3235 session: Session,
3236) -> Result<()> {
3237 let channel_id = ChannelId::from_proto(request.channel_id);
3238 join_channel_internal(channel_id, Box::new(response), session).await
3239}
3240
/// Abstraction over the response handles that can answer a channel join.
///
/// Both implementations below reply with a `proto::JoinRoomResponse`, which lets
/// `join_channel_internal` accept either kind of response handle.
trait JoinChannelInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Response handle for a `proto::JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call that forwards to `Response`'s own `send`.
        // NOTE(review): assumes `Response::send` is an inherent method — confirm.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Response handle for a `proto::JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as above, for the `JoinRoom` request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3254
/// Joins the current session to a channel's room and notifies interested peers.
///
/// Steps, in order:
/// 1. Clean up a stale room connection for this user, if the database reports one.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Send the `JoinRoomResponse` to the requester.
/// 5. Apply any membership change, then broadcast room, channel and contact updates.
///
/// Returns an error if any database call fails, if the response cannot be sent, or
/// if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle first: `leave_room_for_session` acquires
            // its own through `session.db()`. It is re-acquired right afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when creating the token
        // fails (the `?` inside the closure short-circuits to `None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other
                    // role gets a regular room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database handle and connection pool guard from the block above have been
    // dropped by now; the pool is acquired again for the channel broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3350
3351/// Start editing the channel notes
3352async fn join_channel_buffer(
3353 request: proto::JoinChannelBuffer,
3354 response: Response<proto::JoinChannelBuffer>,
3355 session: Session,
3356) -> Result<()> {
3357 let db = session.db().await;
3358 let channel_id = ChannelId::from_proto(request.channel_id);
3359
3360 let open_response = db
3361 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3362 .await?;
3363
3364 let collaborators = open_response.collaborators.clone();
3365 response.send(open_response)?;
3366
3367 let update = UpdateChannelBufferCollaborators {
3368 channel_id: channel_id.to_proto(),
3369 collaborators: collaborators.clone(),
3370 };
3371 channel_buffer_updated(
3372 session.connection_id,
3373 collaborators
3374 .iter()
3375 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3376 &update,
3377 &session.peer,
3378 );
3379
3380 Ok(())
3381}
3382
3383/// Edit the channel notes
3384async fn update_channel_buffer(
3385 request: proto::UpdateChannelBuffer,
3386 session: Session,
3387) -> Result<()> {
3388 let db = session.db().await;
3389 let channel_id = ChannelId::from_proto(request.channel_id);
3390
3391 let (collaborators, epoch, version) = db
3392 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3393 .await?;
3394
3395 channel_buffer_updated(
3396 session.connection_id,
3397 collaborators.clone(),
3398 &proto::UpdateChannelBuffer {
3399 channel_id: channel_id.to_proto(),
3400 operations: request.operations,
3401 },
3402 &session.peer,
3403 );
3404
3405 let pool = &*session.connection_pool().await;
3406
3407 let non_collaborators =
3408 pool.channel_connection_ids(channel_id)
3409 .filter_map(|(connection_id, _)| {
3410 if collaborators.contains(&connection_id) {
3411 None
3412 } else {
3413 Some(connection_id)
3414 }
3415 });
3416
3417 broadcast(None, non_collaborators, |peer_id| {
3418 session.peer.send(
3419 peer_id,
3420 proto::UpdateChannels {
3421 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3422 channel_id: channel_id.to_proto(),
3423 epoch: epoch as u64,
3424 version: version.clone(),
3425 }],
3426 ..Default::default()
3427 },
3428 )
3429 });
3430
3431 Ok(())
3432}
3433
3434/// Rejoin the channel notes after a connection blip
3435async fn rejoin_channel_buffers(
3436 request: proto::RejoinChannelBuffers,
3437 response: Response<proto::RejoinChannelBuffers>,
3438 session: Session,
3439) -> Result<()> {
3440 let db = session.db().await;
3441 let buffers = db
3442 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3443 .await?;
3444
3445 for rejoined_buffer in &buffers {
3446 let collaborators_to_notify = rejoined_buffer
3447 .buffer
3448 .collaborators
3449 .iter()
3450 .filter_map(|c| Some(c.peer_id?.into()));
3451 channel_buffer_updated(
3452 session.connection_id,
3453 collaborators_to_notify,
3454 &proto::UpdateChannelBufferCollaborators {
3455 channel_id: rejoined_buffer.buffer.channel_id,
3456 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3457 },
3458 &session.peer,
3459 );
3460 }
3461
3462 response.send(proto::RejoinChannelBuffersResponse {
3463 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3464 })?;
3465
3466 Ok(())
3467}
3468
3469/// Stop editing the channel notes
3470async fn leave_channel_buffer(
3471 request: proto::LeaveChannelBuffer,
3472 response: Response<proto::LeaveChannelBuffer>,
3473 session: Session,
3474) -> Result<()> {
3475 let db = session.db().await;
3476 let channel_id = ChannelId::from_proto(request.channel_id);
3477
3478 let left_buffer = db
3479 .leave_channel_buffer(channel_id, session.connection_id)
3480 .await?;
3481
3482 response.send(Ack {})?;
3483
3484 channel_buffer_updated(
3485 session.connection_id,
3486 left_buffer.connections,
3487 &proto::UpdateChannelBufferCollaborators {
3488 channel_id: channel_id.to_proto(),
3489 collaborators: left_buffer.collaborators,
3490 },
3491 &session.peer,
3492 );
3493
3494 Ok(())
3495}
3496
3497fn channel_buffer_updated<T: EnvelopedMessage>(
3498 sender_id: ConnectionId,
3499 collaborators: impl IntoIterator<Item = ConnectionId>,
3500 message: &T,
3501 peer: &Peer,
3502) {
3503 broadcast(Some(sender_id), collaborators, |peer_id| {
3504 peer.send(peer_id, message.clone())
3505 });
3506}
3507
3508fn send_notifications(
3509 connection_pool: &ConnectionPool,
3510 peer: &Peer,
3511 notifications: db::NotificationBatch,
3512) {
3513 for (user_id, notification) in notifications {
3514 for connection_id in connection_pool.user_connection_ids(user_id) {
3515 if let Err(error) = peer.send(
3516 connection_id,
3517 proto::AddNotification {
3518 notification: Some(notification.clone()),
3519 },
3520 ) {
3521 tracing::error!(
3522 "failed to send notification to {:?} {}",
3523 connection_id,
3524 error
3525 );
3526 }
3527 }
3528 }
3529}
3530
3531/// Send a message to the channel
3532async fn send_channel_message(
3533 request: proto::SendChannelMessage,
3534 response: Response<proto::SendChannelMessage>,
3535 session: Session,
3536) -> Result<()> {
3537 // Validate the message body.
3538 let body = request.body.trim().to_string();
3539 if body.len() > MAX_MESSAGE_LEN {
3540 return Err(anyhow!("message is too long"))?;
3541 }
3542 if body.is_empty() {
3543 return Err(anyhow!("message can't be blank"))?;
3544 }
3545
3546 // TODO: adjust mentions if body is trimmed
3547
3548 let timestamp = OffsetDateTime::now_utc();
3549 let nonce = request
3550 .nonce
3551 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3552
3553 let channel_id = ChannelId::from_proto(request.channel_id);
3554 let CreatedChannelMessage {
3555 message_id,
3556 participant_connection_ids,
3557 notifications,
3558 } = session
3559 .db()
3560 .await
3561 .create_channel_message(
3562 channel_id,
3563 session.user_id(),
3564 &body,
3565 &request.mentions,
3566 timestamp,
3567 nonce.clone().into(),
3568 request.reply_to_message_id.map(MessageId::from_proto),
3569 )
3570 .await?;
3571
3572 let message = proto::ChannelMessage {
3573 sender_id: session.user_id().to_proto(),
3574 id: message_id.to_proto(),
3575 body,
3576 mentions: request.mentions,
3577 timestamp: timestamp.unix_timestamp() as u64,
3578 nonce: Some(nonce),
3579 reply_to_message_id: request.reply_to_message_id,
3580 edited_at: None,
3581 };
3582 broadcast(
3583 Some(session.connection_id),
3584 participant_connection_ids.clone(),
3585 |connection| {
3586 session.peer.send(
3587 connection,
3588 proto::ChannelMessageSent {
3589 channel_id: channel_id.to_proto(),
3590 message: Some(message.clone()),
3591 },
3592 )
3593 },
3594 );
3595 response.send(proto::SendChannelMessageResponse {
3596 message: Some(message),
3597 })?;
3598
3599 let pool = &*session.connection_pool().await;
3600 let non_participants =
3601 pool.channel_connection_ids(channel_id)
3602 .filter_map(|(connection_id, _)| {
3603 if participant_connection_ids.contains(&connection_id) {
3604 None
3605 } else {
3606 Some(connection_id)
3607 }
3608 });
3609 broadcast(None, non_participants, |peer_id| {
3610 session.peer.send(
3611 peer_id,
3612 proto::UpdateChannels {
3613 latest_channel_message_ids: vec![proto::ChannelMessageId {
3614 channel_id: channel_id.to_proto(),
3615 message_id: message_id.to_proto(),
3616 }],
3617 ..Default::default()
3618 },
3619 )
3620 });
3621 send_notifications(pool, &session.peer, notifications);
3622
3623 Ok(())
3624}
3625
3626/// Delete a channel message
3627async fn remove_channel_message(
3628 request: proto::RemoveChannelMessage,
3629 response: Response<proto::RemoveChannelMessage>,
3630 session: Session,
3631) -> Result<()> {
3632 let channel_id = ChannelId::from_proto(request.channel_id);
3633 let message_id = MessageId::from_proto(request.message_id);
3634 let (connection_ids, existing_notification_ids) = session
3635 .db()
3636 .await
3637 .remove_channel_message(channel_id, message_id, session.user_id())
3638 .await?;
3639
3640 broadcast(
3641 Some(session.connection_id),
3642 connection_ids,
3643 move |connection| {
3644 session.peer.send(connection, request.clone())?;
3645
3646 for notification_id in &existing_notification_ids {
3647 session.peer.send(
3648 connection,
3649 proto::DeleteNotification {
3650 notification_id: (*notification_id).to_proto(),
3651 },
3652 )?;
3653 }
3654
3655 Ok(())
3656 },
3657 );
3658 response.send(proto::Ack {})?;
3659 Ok(())
3660}
3661
3662async fn update_channel_message(
3663 request: proto::UpdateChannelMessage,
3664 response: Response<proto::UpdateChannelMessage>,
3665 session: Session,
3666) -> Result<()> {
3667 let channel_id = ChannelId::from_proto(request.channel_id);
3668 let message_id = MessageId::from_proto(request.message_id);
3669 let updated_at = OffsetDateTime::now_utc();
3670 let UpdatedChannelMessage {
3671 message_id,
3672 participant_connection_ids,
3673 notifications,
3674 reply_to_message_id,
3675 timestamp,
3676 deleted_mention_notification_ids,
3677 updated_mention_notifications,
3678 } = session
3679 .db()
3680 .await
3681 .update_channel_message(
3682 channel_id,
3683 message_id,
3684 session.user_id(),
3685 request.body.as_str(),
3686 &request.mentions,
3687 updated_at,
3688 )
3689 .await?;
3690
3691 let nonce = request
3692 .nonce
3693 .clone()
3694 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3695
3696 let message = proto::ChannelMessage {
3697 sender_id: session.user_id().to_proto(),
3698 id: message_id.to_proto(),
3699 body: request.body.clone(),
3700 mentions: request.mentions.clone(),
3701 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3702 nonce: Some(nonce),
3703 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3704 edited_at: Some(updated_at.unix_timestamp() as u64),
3705 };
3706
3707 response.send(proto::Ack {})?;
3708
3709 let pool = &*session.connection_pool().await;
3710 broadcast(
3711 Some(session.connection_id),
3712 participant_connection_ids,
3713 |connection| {
3714 session.peer.send(
3715 connection,
3716 proto::ChannelMessageUpdate {
3717 channel_id: channel_id.to_proto(),
3718 message: Some(message.clone()),
3719 },
3720 )?;
3721
3722 for notification_id in &deleted_mention_notification_ids {
3723 session.peer.send(
3724 connection,
3725 proto::DeleteNotification {
3726 notification_id: (*notification_id).to_proto(),
3727 },
3728 )?;
3729 }
3730
3731 for notification in &updated_mention_notifications {
3732 session.peer.send(
3733 connection,
3734 proto::UpdateNotification {
3735 notification: Some(notification.clone()),
3736 },
3737 )?;
3738 }
3739
3740 Ok(())
3741 },
3742 );
3743
3744 send_notifications(pool, &session.peer, notifications);
3745
3746 Ok(())
3747}
3748
3749/// Mark a channel message as read
3750async fn acknowledge_channel_message(
3751 request: proto::AckChannelMessage,
3752 session: Session,
3753) -> Result<()> {
3754 let channel_id = ChannelId::from_proto(request.channel_id);
3755 let message_id = MessageId::from_proto(request.message_id);
3756 let notifications = session
3757 .db()
3758 .await
3759 .observe_channel_message(channel_id, session.user_id(), message_id)
3760 .await?;
3761 send_notifications(
3762 &*session.connection_pool().await,
3763 &session.peer,
3764 notifications,
3765 );
3766 Ok(())
3767}
3768
3769/// Mark a buffer version as synced
3770async fn acknowledge_buffer_version(
3771 request: proto::AckBufferOperation,
3772 session: Session,
3773) -> Result<()> {
3774 let buffer_id = BufferId::from_proto(request.buffer_id);
3775 session
3776 .db()
3777 .await
3778 .observe_buffer_version(
3779 buffer_id,
3780 session.user_id(),
3781 request.epoch as i32,
3782 &request.version,
3783 )
3784 .await?;
3785 Ok(())
3786}
3787
3788/// Get a Supermaven API key for the user
3789async fn get_supermaven_api_key(
3790 _request: proto::GetSupermavenApiKey,
3791 response: Response<proto::GetSupermavenApiKey>,
3792 session: Session,
3793) -> Result<()> {
3794 let user_id: String = session.user_id().to_string();
3795 if !session.is_staff() {
3796 return Err(anyhow!("supermaven not enabled for this account"))?;
3797 }
3798
3799 let email = session
3800 .email()
3801 .ok_or_else(|| anyhow!("user must have an email"))?;
3802
3803 let supermaven_admin_api = session
3804 .supermaven_client
3805 .as_ref()
3806 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3807
3808 let result = supermaven_admin_api
3809 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3810 .await?;
3811
3812 response.send(proto::GetSupermavenApiKeyResponse {
3813 api_key: result.api_key,
3814 })?;
3815
3816 Ok(())
3817}
3818
3819/// Start receiving chat updates for a channel
3820async fn join_channel_chat(
3821 request: proto::JoinChannelChat,
3822 response: Response<proto::JoinChannelChat>,
3823 session: Session,
3824) -> Result<()> {
3825 let channel_id = ChannelId::from_proto(request.channel_id);
3826
3827 let db = session.db().await;
3828 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3829 .await?;
3830 let messages = db
3831 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3832 .await?;
3833 response.send(proto::JoinChannelChatResponse {
3834 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3835 messages,
3836 })?;
3837 Ok(())
3838}
3839
3840/// Stop receiving chat updates for a channel
3841async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3842 let channel_id = ChannelId::from_proto(request.channel_id);
3843 session
3844 .db()
3845 .await
3846 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3847 .await?;
3848 Ok(())
3849}
3850
3851/// Retrieve the chat history for a channel
3852async fn get_channel_messages(
3853 request: proto::GetChannelMessages,
3854 response: Response<proto::GetChannelMessages>,
3855 session: Session,
3856) -> Result<()> {
3857 let channel_id = ChannelId::from_proto(request.channel_id);
3858 let messages = session
3859 .db()
3860 .await
3861 .get_channel_messages(
3862 channel_id,
3863 session.user_id(),
3864 MESSAGE_COUNT_PER_PAGE,
3865 Some(MessageId::from_proto(request.before_message_id)),
3866 )
3867 .await?;
3868 response.send(proto::GetChannelMessagesResponse {
3869 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3870 messages,
3871 })?;
3872 Ok(())
3873}
3874
3875/// Retrieve specific chat messages
3876async fn get_channel_messages_by_id(
3877 request: proto::GetChannelMessagesById,
3878 response: Response<proto::GetChannelMessagesById>,
3879 session: Session,
3880) -> Result<()> {
3881 let message_ids = request
3882 .message_ids
3883 .iter()
3884 .map(|id| MessageId::from_proto(*id))
3885 .collect::<Vec<_>>();
3886 let messages = session
3887 .db()
3888 .await
3889 .get_channel_messages_by_id(session.user_id(), &message_ids)
3890 .await?;
3891 response.send(proto::GetChannelMessagesResponse {
3892 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3893 messages,
3894 })?;
3895 Ok(())
3896}
3897
3898/// Retrieve the current users notifications
3899async fn get_notifications(
3900 request: proto::GetNotifications,
3901 response: Response<proto::GetNotifications>,
3902 session: Session,
3903) -> Result<()> {
3904 let notifications = session
3905 .db()
3906 .await
3907 .get_notifications(
3908 session.user_id(),
3909 NOTIFICATION_COUNT_PER_PAGE,
3910 request.before_id.map(db::NotificationId::from_proto),
3911 )
3912 .await?;
3913 response.send(proto::GetNotificationsResponse {
3914 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3915 notifications,
3916 })?;
3917 Ok(())
3918}
3919
3920/// Mark notifications as read
3921async fn mark_notification_as_read(
3922 request: proto::MarkNotificationRead,
3923 response: Response<proto::MarkNotificationRead>,
3924 session: Session,
3925) -> Result<()> {
3926 let database = &session.db().await;
3927 let notifications = database
3928 .mark_notification_as_read_by_id(
3929 session.user_id(),
3930 NotificationId::from_proto(request.notification_id),
3931 )
3932 .await?;
3933 send_notifications(
3934 &*session.connection_pool().await,
3935 &session.peer,
3936 notifications,
3937 );
3938 response.send(proto::Ack {})?;
3939 Ok(())
3940}
3941
3942/// Get the current users information
3943async fn get_private_user_info(
3944 _request: proto::GetPrivateUserInfo,
3945 response: Response<proto::GetPrivateUserInfo>,
3946 session: Session,
3947) -> Result<()> {
3948 let db = session.db().await;
3949
3950 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3951 let user = db
3952 .get_user_by_id(session.user_id())
3953 .await?
3954 .ok_or_else(|| anyhow!("user not found"))?;
3955 let flags = db.get_user_flags(session.user_id()).await?;
3956
3957 response.send(proto::GetPrivateUserInfoResponse {
3958 metrics_id,
3959 staff: user.admin,
3960 flags,
3961 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3962 })?;
3963 Ok(())
3964}
3965
3966/// Accept the terms of service (tos) on behalf of the current user
3967async fn accept_terms_of_service(
3968 _request: proto::AcceptTermsOfService,
3969 response: Response<proto::AcceptTermsOfService>,
3970 session: Session,
3971) -> Result<()> {
3972 let db = session.db().await;
3973
3974 let accepted_tos_at = Utc::now();
3975 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3976 .await?;
3977
3978 response.send(proto::AcceptTermsOfServiceResponse {
3979 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3980 })?;
3981 Ok(())
3982}
3983
/// The minimum age (30 days) an account must have in order to use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3986
3987async fn get_llm_api_token(
3988 _request: proto::GetLlmToken,
3989 response: Response<proto::GetLlmToken>,
3990 session: Session,
3991) -> Result<()> {
3992 let db = session.db().await;
3993
3994 let flags = db.get_user_flags(session.user_id()).await?;
3995 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
3996
3997 if !session.is_staff() && !has_language_models_feature_flag {
3998 Err(anyhow!("permission denied"))?
3999 }
4000
4001 let user_id = session.user_id();
4002 let user = db
4003 .get_user_by_id(user_id)
4004 .await?
4005 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4006
4007 if user.accepted_tos_at.is_none() {
4008 Err(anyhow!("terms of service not accepted"))?
4009 }
4010
4011 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4012 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4013 let billing_preferences = db.get_billing_preferences(user.id).await?;
4014
4015 let token = LlmTokenClaims::create(
4016 &user,
4017 session.is_staff(),
4018 billing_preferences,
4019 &flags,
4020 has_legacy_llm_subscription,
4021 billing_subscription,
4022 session.system_id.clone(),
4023 &session.app_state.config,
4024 )?;
4025 response.send(proto::GetLlmTokenResponse { token })?;
4026 Ok(())
4027}
4028
4029fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4030 let message = match message {
4031 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4032 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4033 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4034 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4035 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4036 code: frame.code.into(),
4037 reason: frame.reason.as_str().to_owned().into(),
4038 })),
4039 // We should never receive a frame while reading the message, according
4040 // to the `tungstenite` maintainers:
4041 //
4042 // > It cannot occur when you read messages from the WebSocket, but it
4043 // > can be used when you want to send the raw frames (e.g. you want to
4044 // > send the frames to the WebSocket without composing the full message first).
4045 // >
4046 // > — https://github.com/snapview/tungstenite-rs/issues/268
4047 TungsteniteMessage::Frame(_) => {
4048 bail!("received an unexpected frame while reading the message")
4049 }
4050 };
4051
4052 Ok(message)
4053}
4054
4055fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4056 match message {
4057 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4058 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4059 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4060 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4061 AxumMessage::Close(frame) => {
4062 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4063 code: frame.code.into(),
4064 reason: frame.reason.as_ref().into(),
4065 }))
4066 }
4067 }
4068}
4069
4070fn notify_membership_updated(
4071 connection_pool: &mut ConnectionPool,
4072 result: MembershipUpdated,
4073 user_id: UserId,
4074 peer: &Peer,
4075) {
4076 for membership in &result.new_channels.channel_memberships {
4077 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4078 }
4079 for channel_id in &result.removed_channels {
4080 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4081 }
4082
4083 let user_channels_update = proto::UpdateUserChannels {
4084 channel_memberships: result
4085 .new_channels
4086 .channel_memberships
4087 .iter()
4088 .map(|cm| proto::ChannelMembership {
4089 channel_id: cm.channel_id.to_proto(),
4090 role: cm.role.into(),
4091 })
4092 .collect(),
4093 ..Default::default()
4094 };
4095
4096 let mut update = build_channels_update(result.new_channels);
4097 update.delete_channels = result
4098 .removed_channels
4099 .into_iter()
4100 .map(|id| id.to_proto())
4101 .collect();
4102 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4103
4104 for connection_id in connection_pool.user_connection_ids(user_id) {
4105 peer.send(connection_id, user_channels_update.clone())
4106 .trace_err();
4107 peer.send(connection_id, update.clone()).trace_err();
4108 }
4109}
4110
4111fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4112 proto::UpdateUserChannels {
4113 channel_memberships: channels
4114 .channel_memberships
4115 .iter()
4116 .map(|m| proto::ChannelMembership {
4117 channel_id: m.channel_id.to_proto(),
4118 role: m.role.into(),
4119 })
4120 .collect(),
4121 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4122 observed_channel_message_id: channels.observed_channel_messages.clone(),
4123 }
4124}
4125
4126fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4127 let mut update = proto::UpdateChannels::default();
4128
4129 for channel in channels.channels {
4130 update.channels.push(channel.to_proto());
4131 }
4132
4133 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4134 update.latest_channel_message_ids = channels.latest_channel_messages;
4135
4136 for (channel_id, participants) in channels.channel_participants {
4137 update
4138 .channel_participants
4139 .push(proto::ChannelParticipants {
4140 channel_id: channel_id.to_proto(),
4141 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4142 });
4143 }
4144
4145 for channel in channels.invited_channels {
4146 update.channel_invitations.push(channel.to_proto());
4147 }
4148
4149 update
4150}
4151
4152fn build_initial_contacts_update(
4153 contacts: Vec<db::Contact>,
4154 pool: &ConnectionPool,
4155) -> proto::UpdateContacts {
4156 let mut update = proto::UpdateContacts::default();
4157
4158 for contact in contacts {
4159 match contact {
4160 db::Contact::Accepted { user_id, busy } => {
4161 update.contacts.push(contact_for_user(user_id, busy, pool));
4162 }
4163 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4164 db::Contact::Incoming { user_id } => {
4165 update
4166 .incoming_requests
4167 .push(proto::IncomingContactRequest {
4168 requester_id: user_id.to_proto(),
4169 })
4170 }
4171 }
4172 }
4173
4174 update
4175}
4176
4177fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4178 proto::Contact {
4179 user_id: user_id.to_proto(),
4180 online: pool.is_user_online(user_id),
4181 busy,
4182 }
4183}
4184
4185fn room_updated(room: &proto::Room, peer: &Peer) {
4186 broadcast(
4187 None,
4188 room.participants
4189 .iter()
4190 .filter_map(|participant| Some(participant.peer_id?.into())),
4191 |peer_id| {
4192 peer.send(
4193 peer_id,
4194 proto::RoomUpdated {
4195 room: Some(room.clone()),
4196 },
4197 )
4198 },
4199 );
4200}
4201
4202fn channel_updated(
4203 channel: &db::channel::Model,
4204 room: &proto::Room,
4205 peer: &Peer,
4206 pool: &ConnectionPool,
4207) {
4208 let participants = room
4209 .participants
4210 .iter()
4211 .map(|p| p.user_id)
4212 .collect::<Vec<_>>();
4213
4214 broadcast(
4215 None,
4216 pool.channel_connection_ids(channel.root_id())
4217 .filter_map(|(channel_id, role)| {
4218 role.can_see_channel(channel.visibility)
4219 .then_some(channel_id)
4220 }),
4221 |peer_id| {
4222 peer.send(
4223 peer_id,
4224 proto::UpdateChannels {
4225 channel_participants: vec![proto::ChannelParticipants {
4226 channel_id: channel.id.to_proto(),
4227 participant_user_ids: participants.clone(),
4228 }],
4229 ..Default::default()
4230 },
4231 )
4232 },
4233 );
4234}
4235
4236async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4237 let db = session.db().await;
4238
4239 let contacts = db.get_contacts(user_id).await?;
4240 let busy = db.is_user_busy(user_id).await?;
4241
4242 let pool = session.connection_pool().await;
4243 let updated_contact = contact_for_user(user_id, busy, &pool);
4244 for contact in contacts {
4245 if let db::Contact::Accepted {
4246 user_id: contact_user_id,
4247 ..
4248 } = contact
4249 {
4250 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4251 session
4252 .peer
4253 .send(
4254 contact_conn_id,
4255 proto::UpdateContacts {
4256 contacts: vec![updated_contact.clone()],
4257 remove_contacts: Default::default(),
4258 incoming_requests: Default::default(),
4259 remove_incoming_requests: Default::default(),
4260 outgoing_requests: Default::default(),
4261 remove_outgoing_requests: Default::default(),
4262 },
4263 )
4264 .trace_err();
4265 }
4266 }
4267 }
4268 Ok(())
4269}
4270
4271async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4272 let mut contacts_to_update = HashSet::default();
4273
4274 let room_id;
4275 let canceled_calls_to_user_ids;
4276 let livekit_room;
4277 let delete_livekit_room;
4278 let room;
4279 let channel;
4280
4281 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4282 contacts_to_update.insert(session.user_id());
4283
4284 for project in left_room.left_projects.values() {
4285 project_left(project, session);
4286 }
4287
4288 room_id = RoomId::from_proto(left_room.room.id);
4289 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4290 livekit_room = mem::take(&mut left_room.room.livekit_room);
4291 delete_livekit_room = left_room.deleted;
4292 room = mem::take(&mut left_room.room);
4293 channel = mem::take(&mut left_room.channel);
4294
4295 room_updated(&room, &session.peer);
4296 } else {
4297 return Ok(());
4298 }
4299
4300 if let Some(channel) = channel {
4301 channel_updated(
4302 &channel,
4303 &room,
4304 &session.peer,
4305 &*session.connection_pool().await,
4306 );
4307 }
4308
4309 {
4310 let pool = session.connection_pool().await;
4311 for canceled_user_id in canceled_calls_to_user_ids {
4312 for connection_id in pool.user_connection_ids(canceled_user_id) {
4313 session
4314 .peer
4315 .send(
4316 connection_id,
4317 proto::CallCanceled {
4318 room_id: room_id.to_proto(),
4319 },
4320 )
4321 .trace_err();
4322 }
4323 contacts_to_update.insert(canceled_user_id);
4324 }
4325 }
4326
4327 for contact_user_id in contacts_to_update {
4328 update_user_contacts(contact_user_id, session).await?;
4329 }
4330
4331 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4332 live_kit
4333 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4334 .await
4335 .trace_err();
4336
4337 if delete_livekit_room {
4338 live_kit.delete_room(livekit_room).await.trace_err();
4339 }
4340 }
4341
4342 Ok(())
4343}
4344
4345async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4346 let left_channel_buffers = session
4347 .db()
4348 .await
4349 .leave_channel_buffers(session.connection_id)
4350 .await?;
4351
4352 for left_buffer in left_channel_buffers {
4353 channel_buffer_updated(
4354 session.connection_id,
4355 left_buffer.connections,
4356 &proto::UpdateChannelBufferCollaborators {
4357 channel_id: left_buffer.channel_id.to_proto(),
4358 collaborators: left_buffer.collaborators,
4359 },
4360 &session.peer,
4361 );
4362 }
4363
4364 Ok(())
4365}
4366
4367fn project_left(project: &db::LeftProject, session: &Session) {
4368 for connection_id in &project.connection_ids {
4369 if project.should_unshare {
4370 session
4371 .peer
4372 .send(
4373 *connection_id,
4374 proto::UnshareProject {
4375 project_id: project.id.to_proto(),
4376 },
4377 )
4378 .trace_err();
4379 } else {
4380 session
4381 .peer
4382 .send(
4383 *connection_id,
4384 proto::RemoveProjectCollaborator {
4385 project_id: project.id.to_proto(),
4386 peer_id: Some(session.connection_id.into()),
4387 },
4388 )
4389 .trace_err();
4390 }
4391 }
4392}
4393
/// Extension trait for turning a fallible value into an `Option`.
///
/// It lets callers handle failures on a best-effort basis: the error is dealt
/// with by the implementation and the caller only sees whether a value exists.
pub trait ResultExt {
    /// The type produced when the operation succeeded.
    type Ok;

    /// Consumes `self`, yielding `Some` with the success value or `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4399
4400impl<T, E> ResultExt for Result<T, E>
4401where
4402 E: std::fmt::Debug,
4403{
4404 type Ok = T;
4405
4406 #[track_caller]
4407 fn trace_err(self) -> Option<T> {
4408 match self {
4409 Ok(value) => Some(value),
4410 Err(error) => {
4411 tracing::error!("{:?}", error);
4412 None
4413 }
4414 }
4415 }
4416}