1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
6use crate::{
7 AppState, Error, Result, auth,
8 db::{
9 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
10 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
11 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
12 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15};
16use anyhow::{Context as _, anyhow, bail};
17use async_tungstenite::tungstenite::{
18 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
19};
20use axum::{
21 Extension, Router, TypedHeader,
22 body::Body,
23 extract::{
24 ConnectInfo, WebSocketUpgrade,
25 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use reqwest_client::ReqwestClient;
38use rpc::proto::split_repository_update;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{MutexGuard, Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// How long a disconnected client is given to reconnect before its session is cleaned up.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up resources left by previous servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone we can
/// safely clean up their old resources, so we wait a little longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length for a single channel message.
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
86
87type MessageHandler =
88 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone, Debug)]
105pub enum Principal {
106 User(User),
107 Impersonated { user: User, admin: User },
108}
109
110impl Principal {
111 fn update_span(&self, span: &tracing::Span) {
112 match &self {
113 Principal::User(user) => {
114 span.record("user_id", user.id.0);
115 span.record("login", &user.github_login);
116 }
117 Principal::Impersonated { user, admin } => {
118 span.record("user_id", user.id.0);
119 span.record("login", &user.github_login);
120 span.record("impersonator", &admin.github_login);
121 }
122 }
123 }
124}
125
/// Per-connection state handed (as a clone) to every message handler for that connection.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as.
    principal: Principal,
    connection_id: ConnectionId,
    /// One handle is created per connection and shared by all clones of the session, so
    /// database access from this connection's handlers goes through a single async mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `None` when no Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): not read by any code visible in this chunk, hence the `allow`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
141
142impl Session {
143 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
144 #[cfg(test)]
145 tokio::task::yield_now().await;
146 let guard = self.db.lock().await;
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 guard
150 }
151
152 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.connection_pool.lock();
156 ConnectionPoolGuard {
157 guard,
158 _not_send: PhantomData,
159 }
160 }
161
162 fn is_staff(&self) -> bool {
163 match &self.principal {
164 Principal::User(user) => user.admin,
165 Principal::Impersonated { .. } => true,
166 }
167 }
168
169 pub async fn has_llm_subscription(
170 &self,
171 db: &MutexGuard<'_, DbHandle>,
172 ) -> anyhow::Result<bool> {
173 if self.is_staff() {
174 return Ok(true);
175 }
176
177 let user_id = self.user_id();
178
179 Ok(db.has_active_billing_subscription(user_id).await?)
180 }
181
182 pub async fn current_plan(&self, db: &MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
183 if self.is_staff() {
184 return Ok(proto::Plan::ZedPro);
185 }
186
187 let user_id = self.user_id();
188
189 let subscription = db.get_active_billing_subscription(user_id).await?;
190 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
191
192 let plan = if let Some(subscription_kind) = subscription_kind {
193 match subscription_kind {
194 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
195 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
196 SubscriptionKind::ZedFree => proto::Plan::Free,
197 }
198 } else {
199 proto::Plan::Free
200 };
201
202 Ok(plan)
203 }
204
205 fn user_id(&self) -> UserId {
206 match &self.principal {
207 Principal::User(user) => user.id,
208 Principal::Impersonated { user, .. } => user.id,
209 }
210 }
211
212 pub fn email(&self) -> Option<String> {
213 match &self.principal {
214 Principal::User(user) => user.email_address.clone(),
215 Principal::Impersonated { user, .. } => user.email_address.clone(),
216 }
217 }
218}
219
220impl Debug for Session {
221 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
222 let mut result = f.debug_struct("Session");
223 match &self.principal {
224 Principal::User(user) => {
225 result.field("user", &user.github_login);
226 }
227 Principal::Impersonated { user, admin } => {
228 result.field("user", &user.github_login);
229 result.field("impersonator", &admin.github_login);
230 }
231 }
232 result.field("connection_id", &self.connection_id).finish()
233 }
234}
235
236struct DbHandle(Arc<Database>);
237
238impl Deref for DbHandle {
239 type Target = Database;
240
241 fn deref(&self) -> &Self::Target {
242 self.0.as_ref()
243 }
244}
245
/// The collaboration RPC server: owns the peer, the pool of live connections, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// This server instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; each connection subscribes to it and stops when it changes.
    teardown: watch::Sender<bool>,
}
254
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, which makes the guard `!Send` too, so it cannot be held across an
    /// `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
259
/// Serializable view of the server's peer and connection pool.
/// Holds the connection pool lock for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself rather than the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
266
267pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
268where
269 S: Serializer,
270 T: Deref<Target = U>,
271 U: Serialize,
272{
273 Serialize::serialize(value.deref(), serializer)
274}
275
276impl Server {
277 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
278 let mut server = Self {
279 id: parking_lot::Mutex::new(id),
280 peer: Peer::new(id.0 as u32),
281 app_state: app_state.clone(),
282 connection_pool: Default::default(),
283 handlers: Default::default(),
284 teardown: watch::channel(false).0,
285 };
286
287 server
288 .add_request_handler(ping)
289 .add_request_handler(create_room)
290 .add_request_handler(join_room)
291 .add_request_handler(rejoin_room)
292 .add_request_handler(leave_room)
293 .add_request_handler(set_room_participant_role)
294 .add_request_handler(call)
295 .add_request_handler(cancel_call)
296 .add_message_handler(decline_call)
297 .add_request_handler(update_participant_location)
298 .add_request_handler(share_project)
299 .add_message_handler(unshare_project)
300 .add_request_handler(join_project)
301 .add_message_handler(leave_project)
302 .add_request_handler(update_project)
303 .add_request_handler(update_worktree)
304 .add_request_handler(update_repository)
305 .add_request_handler(remove_repository)
306 .add_message_handler(start_language_server)
307 .add_message_handler(update_language_server)
308 .add_message_handler(update_diagnostic_summary)
309 .add_message_handler(update_worktree_settings)
310 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
311 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
312 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
313 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
314 .add_request_handler(forward_find_search_candidates_request)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
317 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
318 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
319 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
320 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
321 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
322 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
323 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
324 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
325 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
326 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
327 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
328 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
329 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
330 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
331 .add_request_handler(
332 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
333 )
334 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
335 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
336 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
337 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
338 .add_request_handler(
339 forward_read_only_project_request::<proto::LanguageServerIdForName>,
340 )
341 .add_request_handler(
342 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
343 )
344 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
345 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
346 .add_request_handler(
347 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
348 )
349 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
350 .add_request_handler(
351 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
352 )
353 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
354 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
355 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
356 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
357 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
358 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
359 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
360 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
361 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
362 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
363 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
365 .add_request_handler(
366 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
367 )
368 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
369 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
370 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
371 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
372 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
373 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
374 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
375 .add_message_handler(create_buffer_for_peer)
376 .add_request_handler(update_buffer)
377 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
378 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
379 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
380 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
381 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
382 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
383 .add_request_handler(get_users)
384 .add_request_handler(fuzzy_search_users)
385 .add_request_handler(request_contact)
386 .add_request_handler(remove_contact)
387 .add_request_handler(respond_to_contact_request)
388 .add_message_handler(subscribe_to_channels)
389 .add_request_handler(create_channel)
390 .add_request_handler(delete_channel)
391 .add_request_handler(invite_channel_member)
392 .add_request_handler(remove_channel_member)
393 .add_request_handler(set_channel_member_role)
394 .add_request_handler(set_channel_visibility)
395 .add_request_handler(rename_channel)
396 .add_request_handler(join_channel_buffer)
397 .add_request_handler(leave_channel_buffer)
398 .add_message_handler(update_channel_buffer)
399 .add_request_handler(rejoin_channel_buffers)
400 .add_request_handler(get_channel_members)
401 .add_request_handler(respond_to_channel_invite)
402 .add_request_handler(join_channel)
403 .add_request_handler(join_channel_chat)
404 .add_message_handler(leave_channel_chat)
405 .add_request_handler(send_channel_message)
406 .add_request_handler(remove_channel_message)
407 .add_request_handler(update_channel_message)
408 .add_request_handler(get_channel_messages)
409 .add_request_handler(get_channel_messages_by_id)
410 .add_request_handler(get_notifications)
411 .add_request_handler(mark_notification_as_read)
412 .add_request_handler(move_channel)
413 .add_request_handler(follow)
414 .add_message_handler(unfollow)
415 .add_message_handler(update_followers)
416 .add_request_handler(get_private_user_info)
417 .add_request_handler(get_llm_api_token)
418 .add_request_handler(accept_terms_of_service)
419 .add_message_handler(acknowledge_channel_message)
420 .add_message_handler(acknowledge_buffer_version)
421 .add_request_handler(get_supermaven_api_key)
422 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
423 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
424 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
425 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
426 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
427 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
428 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
429 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
430 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
431 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
432 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
433 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
434 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
435 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
436 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
437 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
438 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
439 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
440 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
441 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
442 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
443 .add_message_handler(update_context);
444
445 Arc::new(server)
446 }
447
    /// Kicks off background cleanup of resources left behind by previous server instances.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then asks the database for the rooms and
    /// channel buffers it reports as stale for this environment and `server_id`. It clears
    /// their stale participants/collaborators and notifies the affected clients. Finally it
    /// deletes the stale server records. Failures along the way are logged via `trace_err`
    /// and do not stop the remaining steps.
    ///
    /// Returns `Ok(())` immediately; the cleanup runs in the background.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Give terminated servers time to shut down before touching their resources.
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their defaults (nothing to notify, nothing to delete)
                        // when clearing the room's stale participants fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy status)
                        // to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
593
594 pub fn teardown(&self) {
595 self.peer.teardown();
596 self.connection_pool.lock().reset();
597 let _ = self.teardown.send(true);
598 }
599
600 #[cfg(test)]
601 pub fn reset(&self, id: ServerId) {
602 self.teardown();
603 *self.id.lock() = id;
604 self.peer.reset(id.0 as u32);
605 let _ = self.teardown.send(false);
606 }
607
608 #[cfg(test)]
609 pub fn id(&self) -> ServerId {
610 *self.id.lock()
611 }
612
613 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
614 where
615 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
616 Fut: 'static + Send + Future<Output = Result<()>>,
617 M: EnvelopedMessage,
618 {
619 let prev_handler = self.handlers.insert(
620 TypeId::of::<M>(),
621 Box::new(move |envelope, session| {
622 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
623 let received_at = envelope.received_at;
624 tracing::info!("message received");
625 let start_time = Instant::now();
626 let future = (handler)(*envelope, session);
627 async move {
628 let result = future.await;
629 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
630 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
631 let queue_duration_ms = total_duration_ms - processing_duration_ms;
632 let payload_type = M::NAME;
633
634 match result {
635 Err(error) => {
636 tracing::error!(
637 ?error,
638 total_duration_ms,
639 processing_duration_ms,
640 queue_duration_ms,
641 payload_type,
642 "error handling message"
643 )
644 }
645 Ok(()) => tracing::info!(
646 total_duration_ms,
647 processing_duration_ms,
648 queue_duration_ms,
649 "finished handling message"
650 ),
651 }
652 }
653 .boxed()
654 }),
655 );
656 if prev_handler.is_some() {
657 panic!("registered a handler for the same message twice");
658 }
659 self
660 }
661
662 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
663 where
664 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
665 Fut: 'static + Send + Future<Output = Result<()>>,
666 M: EnvelopedMessage,
667 {
668 self.add_handler(move |envelope, session| handler(envelope.payload, session));
669 self
670 }
671
672 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
673 where
674 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
675 Fut: Send + Future<Output = Result<()>>,
676 M: RequestMessage,
677 {
678 let handler = Arc::new(handler);
679 self.add_handler(move |envelope, session| {
680 let receipt = envelope.receipt();
681 let handler = handler.clone();
682 async move {
683 let peer = session.peer.clone();
684 let responded = Arc::new(AtomicBool::default());
685 let response = Response {
686 peer: peer.clone(),
687 responded: responded.clone(),
688 receipt,
689 };
690 match (handler)(envelope.payload, response, session).await {
691 Ok(()) => {
692 if responded.load(std::sync::atomic::Ordering::SeqCst) {
693 Ok(())
694 } else {
695 Err(anyhow!("handler did not send a response"))?
696 }
697 }
698 Err(error) => {
699 let proto_err = match &error {
700 Error::Internal(err) => err.to_proto(),
701 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
702 };
703 peer.respond_with_error(receipt, proto_err)?;
704 Err(error)
705 }
706 }
707 }
708 })
709 }
710
711 pub fn handle_connection(
712 self: &Arc<Self>,
713 connection: Connection,
714 address: String,
715 principal: Principal,
716 zed_version: ZedVersion,
717 geoip_country_code: Option<String>,
718 system_id: Option<String>,
719 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
720 executor: Executor,
721 ) -> impl Future<Output = ()> + use<> {
722 let this = self.clone();
723 let span = info_span!("handle connection", %address,
724 connection_id=field::Empty,
725 user_id=field::Empty,
726 login=field::Empty,
727 impersonator=field::Empty,
728 geoip_country_code=field::Empty
729 );
730 principal.update_span(&span);
731 if let Some(country_code) = geoip_country_code.as_ref() {
732 span.record("geoip_country_code", country_code);
733 }
734
735 let mut teardown = self.teardown.subscribe();
736 async move {
737 if *teardown.borrow() {
738 tracing::error!("server is tearing down");
739 return
740 }
741 let (connection_id, handle_io, mut incoming_rx) = this
742 .peer
743 .add_connection(connection, {
744 let executor = executor.clone();
745 move |duration| executor.sleep(duration)
746 });
747 tracing::Span::current().record("connection_id", format!("{}", connection_id));
748
749 tracing::info!("connection opened");
750
751 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
752 let http_client = match ReqwestClient::user_agent(&user_agent) {
753 Ok(http_client) => Arc::new(http_client),
754 Err(error) => {
755 tracing::error!(?error, "failed to create HTTP client");
756 return;
757 }
758 };
759
760 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
761 supermaven_admin_api_key.to_string(),
762 http_client.clone(),
763 )));
764
765 let session = Session {
766 principal: principal.clone(),
767 connection_id,
768 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
769 peer: this.peer.clone(),
770 connection_pool: this.connection_pool.clone(),
771 app_state: this.app_state.clone(),
772 geoip_country_code,
773 system_id,
774 _executor: executor.clone(),
775 supermaven_client,
776 };
777
778 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
779 tracing::error!(?error, "failed to send initial client update");
780 return;
781 }
782
783 let handle_io = handle_io.fuse();
784 futures::pin_mut!(handle_io);
785
786 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
787 // This prevents deadlocks when e.g., client A performs a request to client B and
788 // client B performs a request to client A. If both clients stop processing further
789 // messages until their respective request completes, they won't have a chance to
790 // respond to the other client's request and cause a deadlock.
791 //
792 // This arrangement ensures we will attempt to process earlier messages first, but fall
793 // back to processing messages arrived later in the spirit of making progress.
794 let mut foreground_message_handlers = FuturesUnordered::new();
795 let concurrent_handlers = Arc::new(Semaphore::new(256));
796 loop {
797 let next_message = async {
798 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
799 let message = incoming_rx.next().await;
800 (permit, message)
801 }.fuse();
802 futures::pin_mut!(next_message);
803 futures::select_biased! {
804 _ = teardown.changed().fuse() => return,
805 result = handle_io => {
806 if let Err(error) = result {
807 tracing::error!(?error, "error handling I/O");
808 }
809 break;
810 }
811 _ = foreground_message_handlers.next() => {}
812 next_message = next_message => {
813 let (permit, message) = next_message;
814 if let Some(message) = message {
815 let type_name = message.payload_type_name();
816 // note: we copy all the fields from the parent span so we can query them in the logs.
817 // (https://github.com/tokio-rs/tracing/issues/2670).
818 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
819 user_id=field::Empty,
820 login=field::Empty,
821 impersonator=field::Empty,
822 );
823 principal.update_span(&span);
824 let span_enter = span.enter();
825 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
826 let is_background = message.is_background();
827 let handle_message = (handler)(message, session.clone());
828 drop(span_enter);
829
830 let handle_message = async move {
831 handle_message.await;
832 drop(permit);
833 }.instrument(span);
834 if is_background {
835 executor.spawn_detached(handle_message);
836 } else {
837 foreground_message_handlers.push(handle_message);
838 }
839 } else {
840 tracing::error!("no message handler");
841 }
842 } else {
843 tracing::info!("connection closed");
844 break;
845 }
846 }
847 }
848 }
849
850 drop(foreground_message_handlers);
851 tracing::info!("signing out");
852 if let Err(error) = connection_lost(session, teardown, executor).await {
853 tracing::error!(?error, "error signing out");
854 }
855
856 }.instrument(span)
857 }
858
859 async fn send_initial_client_update(
860 &self,
861 connection_id: ConnectionId,
862 principal: &Principal,
863 zed_version: ZedVersion,
864 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
865 session: &Session,
866 ) -> Result<()> {
867 self.peer.send(
868 connection_id,
869 proto::Hello {
870 peer_id: Some(connection_id.into()),
871 },
872 )?;
873 tracing::info!("sent hello message");
874 if let Some(send_connection_id) = send_connection_id.take() {
875 let _ = send_connection_id.send(connection_id);
876 }
877
878 match principal {
879 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
880 if !user.connected_once {
881 self.peer.send(connection_id, proto::ShowContacts {})?;
882 self.app_state
883 .db
884 .set_user_connected_once(user.id, true)
885 .await?;
886 }
887
888 update_user_plan(user.id, session).await?;
889
890 let contacts = self.app_state.db.get_contacts(user.id).await?;
891
892 {
893 let mut pool = self.connection_pool.lock();
894 pool.add_connection(connection_id, user.id, user.admin, zed_version);
895 self.peer.send(
896 connection_id,
897 build_initial_contacts_update(contacts, &pool),
898 )?;
899 }
900
901 if should_auto_subscribe_to_channels(zed_version) {
902 subscribe_user_to_channels(user.id, session).await?;
903 }
904
905 if let Some(incoming_call) =
906 self.app_state.db.incoming_call_for_user(user.id).await?
907 {
908 self.peer.send(connection_id, incoming_call)?;
909 }
910
911 update_user_contacts(user.id, session).await?;
912 }
913 }
914
915 Ok(())
916 }
917
918 pub async fn invite_code_redeemed(
919 self: &Arc<Self>,
920 inviter_id: UserId,
921 invitee_id: UserId,
922 ) -> Result<()> {
923 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
924 if let Some(code) = &user.invite_code {
925 let pool = self.connection_pool.lock();
926 let invitee_contact = contact_for_user(invitee_id, false, &pool);
927 for connection_id in pool.user_connection_ids(inviter_id) {
928 self.peer.send(
929 connection_id,
930 proto::UpdateContacts {
931 contacts: vec![invitee_contact.clone()],
932 ..Default::default()
933 },
934 )?;
935 self.peer.send(
936 connection_id,
937 proto::UpdateInviteInfo {
938 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
939 count: user.invite_count as u32,
940 },
941 )?;
942 }
943 }
944 }
945 Ok(())
946 }
947
948 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
949 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
950 if let Some(invite_code) = &user.invite_code {
951 let pool = self.connection_pool.lock();
952 for connection_id in pool.user_connection_ids(user_id) {
953 self.peer.send(
954 connection_id,
955 proto::UpdateInviteInfo {
956 url: format!(
957 "{}{}",
958 self.app_state.config.invite_link_prefix, invite_code
959 ),
960 count: user.invite_count as u32,
961 },
962 )?;
963 }
964 }
965 }
966 Ok(())
967 }
968
969 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
970 let pool = self.connection_pool.lock();
971 for connection_id in pool.user_connection_ids(user_id) {
972 self.peer
973 .send(connection_id, proto::RefreshLlmToken {})
974 .trace_err();
975 }
976 }
977
978 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
979 ServerSnapshot {
980 connection_pool: ConnectionPoolGuard {
981 guard: self.connection_pool.lock(),
982 _not_send: PhantomData,
983 },
984 peer: &self.peer,
985 }
986 }
987}
988
989impl Deref for ConnectionPoolGuard<'_> {
990 type Target = ConnectionPool;
991
992 fn deref(&self) -> &Self::Target {
993 &self.guard
994 }
995}
996
997impl DerefMut for ConnectionPoolGuard<'_> {
998 fn deref_mut(&mut self) -> &mut Self::Target {
999 &mut self.guard
1000 }
1001}
1002
/// In test builds, dropping the guard runs `check_invariants` on the pool.
/// `Drop::drop` runs before the `guard` field is dropped, so the check happens
/// while the lock is still held, after any mutations made through this guard.
/// Non-test builds do nothing here.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1009
1010fn broadcast<F>(
1011 sender_id: Option<ConnectionId>,
1012 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1013 mut f: F,
1014) where
1015 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1016{
1017 for receiver_id in receiver_ids {
1018 if Some(receiver_id) != sender_id {
1019 if let Err(error) = f(receiver_id) {
1020 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1021 }
1022 }
1023 }
1024}
1025
1026pub struct ProtocolVersion(u32);
1027
1028impl Header for ProtocolVersion {
1029 fn name() -> &'static HeaderName {
1030 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1032 }
1033
1034 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035 where
1036 Self: Sized,
1037 I: Iterator<Item = &'i axum::http::HeaderValue>,
1038 {
1039 let version = values
1040 .next()
1041 .ok_or_else(axum::headers::Error::invalid)?
1042 .to_str()
1043 .map_err(|_| axum::headers::Error::invalid())?
1044 .parse()
1045 .map_err(|_| axum::headers::Error::invalid())?;
1046 Ok(Self(version))
1047 }
1048
1049 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050 values.extend([self.0.to_string().parse().unwrap()]);
1051 }
1052}
1053
1054pub struct AppVersionHeader(SemanticVersion);
1055impl Header for AppVersionHeader {
1056 fn name() -> &'static HeaderName {
1057 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1058 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1059 }
1060
1061 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1062 where
1063 Self: Sized,
1064 I: Iterator<Item = &'i axum::http::HeaderValue>,
1065 {
1066 let version = values
1067 .next()
1068 .ok_or_else(axum::headers::Error::invalid)?
1069 .to_str()
1070 .map_err(|_| axum::headers::Error::invalid())?
1071 .parse()
1072 .map_err(|_| axum::headers::Error::invalid())?;
1073 Ok(Self(version))
1074 }
1075
1076 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1077 values.extend([self.0.to_string().parse().unwrap()]);
1078 }
1079}
1080
/// Builds the collab RPC router.
///
/// `/rpc` serves the WebSocket endpoint. It is wrapped in an `AppState`
/// extension and the `auth::validate_header` middleware. `/metrics` is
/// registered after that `.layer(...)` call. axum's `Router::layer` applies
/// middleware only to routes added before the call, so `/metrics` is not
/// behind `validate_header`. Preserve this ordering when adding routes. The
/// final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to the routes registered above (i.e. `/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1092
1093pub async fn handle_websocket_request(
1094 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1095 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1096 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1097 Extension(server): Extension<Arc<Server>>,
1098 Extension(principal): Extension<Principal>,
1099 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1100 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1101 ws: WebSocketUpgrade,
1102) -> axum::response::Response {
1103 if protocol_version != rpc::PROTOCOL_VERSION {
1104 return (
1105 StatusCode::UPGRADE_REQUIRED,
1106 "client must be upgraded".to_string(),
1107 )
1108 .into_response();
1109 }
1110
1111 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1112 return (
1113 StatusCode::UPGRADE_REQUIRED,
1114 "no version header found".to_string(),
1115 )
1116 .into_response();
1117 };
1118
1119 if !version.can_collaborate() {
1120 return (
1121 StatusCode::UPGRADE_REQUIRED,
1122 "client must be upgraded".to_string(),
1123 )
1124 .into_response();
1125 }
1126
1127 let socket_address = socket_address.to_string();
1128 ws.on_upgrade(move |socket| {
1129 let socket = socket
1130 .map_ok(to_tungstenite_message)
1131 .err_into()
1132 .with(|message| async move { to_axum_message(message) });
1133 let connection = Connection::new(Box::pin(socket));
1134 async move {
1135 server
1136 .handle_connection(
1137 connection,
1138 socket_address,
1139 principal,
1140 version,
1141 country_code_header.map(|header| header.to_string()),
1142 system_id_header.map(|header| header.to_string()),
1143 None,
1144 Executor::Production,
1145 )
1146 .await;
1147 }
1148 })
1149}
1150
1151pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1152 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153 let connections_metric = CONNECTIONS_METRIC
1154 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1155
1156 let connections = server
1157 .connection_pool
1158 .lock()
1159 .connections()
1160 .filter(|connection| !connection.admin)
1161 .count();
1162 connections_metric.set(connections as _);
1163
1164 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1165 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1166 register_int_gauge!(
1167 "shared_projects",
1168 "number of open projects with one or more guests"
1169 )
1170 .unwrap()
1171 });
1172
1173 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1174 shared_projects_metric.set(shared_projects as _);
1175
1176 let encoder = prometheus::TextEncoder::new();
1177 let metric_families = prometheus::gather();
1178 let encoded_metrics = encoder
1179 .encode_to_string(&metric_families)
1180 .map_err(|err| anyhow!("{}", err))?;
1181 Ok(encoded_metrics)
1182}
1183
/// Cleans up after a client connection has gone away.
///
/// Immediate steps:
/// - disconnect the connection from the peer;
/// - remove it from the connection pool (an error here aborts the rest of the
///   cleanup);
/// - tell the database the connection was lost (failures are only logged).
///
/// It then waits on two events:
/// - If `RECONNECT_TIMEOUT` elapses first, the session's room and
///   channel-buffer memberships are torn down. If the pool no longer reports
///   the user as online, any pending call is declined. Finally the user's
///   contacts are updated, and an error from that update is returned.
/// - If the `teardown` watch value changes first, that extra cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the
    // reconnect timer is checked before the teardown signal when both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this
            // user is still online to answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1230
1231/// Acknowledges a ping from a client, used to keep the connection alive.
1232async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1233 response.send(proto::Ack {})?;
1234 Ok(())
1235}
1236
1237/// Creates a new room for calling (outside of channels)
1238async fn create_room(
1239 _request: proto::CreateRoom,
1240 response: Response<proto::CreateRoom>,
1241 session: Session,
1242) -> Result<()> {
1243 let livekit_room = nanoid::nanoid!(30);
1244
1245 let live_kit_connection_info = util::maybe!(async {
1246 let live_kit = session.app_state.livekit_client.as_ref();
1247 let live_kit = live_kit?;
1248 let user_id = session.user_id().to_string();
1249
1250 let token = live_kit
1251 .room_token(&livekit_room, &user_id.to_string())
1252 .trace_err()?;
1253
1254 Some(proto::LiveKitConnectionInfo {
1255 server_url: live_kit.url().into(),
1256 token,
1257 can_publish: true,
1258 })
1259 })
1260 .await;
1261
1262 let room = session
1263 .db()
1264 .await
1265 .create_room(session.user_id(), session.connection_id, &livekit_room)
1266 .await?;
1267
1268 response.send(proto::CreateRoomResponse {
1269 room: Some(room.clone()),
1270 live_kit_connection_info,
1271 })?;
1272
1273 update_user_contacts(session.user_id(), &session).await?;
1274 Ok(())
1275}
1276
1277/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1278async fn join_room(
1279 request: proto::JoinRoom,
1280 response: Response<proto::JoinRoom>,
1281 session: Session,
1282) -> Result<()> {
1283 let room_id = RoomId::from_proto(request.id);
1284
1285 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1286
1287 if let Some(channel_id) = channel_id {
1288 return join_channel_internal(channel_id, Box::new(response), session).await;
1289 }
1290
1291 let joined_room = {
1292 let room = session
1293 .db()
1294 .await
1295 .join_room(room_id, session.user_id(), session.connection_id)
1296 .await?;
1297 room_updated(&room.room, &session.peer);
1298 room.into_inner()
1299 };
1300
1301 for connection_id in session
1302 .connection_pool()
1303 .await
1304 .user_connection_ids(session.user_id())
1305 {
1306 session
1307 .peer
1308 .send(
1309 connection_id,
1310 proto::CallCanceled {
1311 room_id: room_id.to_proto(),
1312 },
1313 )
1314 .trace_err();
1315 }
1316
1317 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1318 {
1319 live_kit
1320 .room_token(
1321 &joined_room.room.livekit_room,
1322 &session.user_id().to_string(),
1323 )
1324 .trace_err()
1325 .map(|token| proto::LiveKitConnectionInfo {
1326 server_url: live_kit.url().into(),
1327 token,
1328 can_publish: true,
1329 })
1330 } else {
1331 None
1332 };
1333
1334 response.send(proto::JoinRoomResponse {
1335 room: Some(joined_room.room),
1336 channel_id: None,
1337 live_kit_connection_info,
1338 })?;
1339
1340 update_user_contacts(session.user_id(), &session).await?;
1341 Ok(())
1342}
1343
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Call the database's `rejoin_room` with the request, user and current
///    connection.
/// 2. Reply with the room, `reshared_projects` and `rejoined_projects`, then
///    broadcast a room update.
/// 3. For each reshared project, tell every collaborator that this user's peer
///    id moved from `old_connection_id` to the current connection. Then
///    forward the project's worktree list to them as `UpdateProject`. These
///    appear to be projects this connection hosts (TODO confirm).
/// 4. Bring rejoined projects up to date via `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, send a channel update, then update the
///    user's contacts.
///
/// Per-collaborator send failures are only logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the guard returned by `rejoin_room` happens in this
    // block; `into_inner()` consumes it before the channel/contact updates.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-point every collaborator from the old peer id to this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1435
/// Re-synchronizes the projects returned as `rejoined_projects` by a room
/// rejoin.
///
/// First, every collaborator of each project is told that this user's peer id
/// changed from `old_connection_id` to the session's current connection. Those
/// send failures are only logged.
///
/// Then, for each project, the rejoining connection is sent:
/// - each worktree's entries, split into chunks by
///   `proto::split_worktree_update`;
/// - the worktree's diagnostic summaries and settings files;
/// - updated repositories, split by `split_repository_update`;
/// - removed repository ids.
///
/// Each project's worktree and repository lists are taken (left empty) while
/// streaming, which is why the argument is `&mut`.
///
/// # Errors
/// Returns the first error from sending a message to the rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `worktree` is owned here; its fields are moved out one by one below.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Last update once `completed_scan_id` has caught up with `scan_id`.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repositories changed while the connection was away, possibly split
        // into several messages.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1520
1521/// leave room disconnects from the room.
1522async fn leave_room(
1523 _: proto::LeaveRoom,
1524 response: Response<proto::LeaveRoom>,
1525 session: Session,
1526) -> Result<()> {
1527 leave_room_for_session(&session, session.connection_id).await?;
1528 response.send(proto::Ack {})?;
1529 Ok(())
1530}
1531
1532/// Updates the permissions of someone else in the room.
1533async fn set_room_participant_role(
1534 request: proto::SetRoomParticipantRole,
1535 response: Response<proto::SetRoomParticipantRole>,
1536 session: Session,
1537) -> Result<()> {
1538 let user_id = UserId::from_proto(request.user_id);
1539 let role = ChannelRole::from(request.role());
1540
1541 let (livekit_room, can_publish) = {
1542 let room = session
1543 .db()
1544 .await
1545 .set_room_participant_role(
1546 session.user_id(),
1547 RoomId::from_proto(request.room_id),
1548 user_id,
1549 role,
1550 )
1551 .await?;
1552
1553 let livekit_room = room.livekit_room.clone();
1554 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1555 room_updated(&room, &session.peer);
1556 (livekit_room, can_publish)
1557 };
1558
1559 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1560 live_kit
1561 .update_participant(
1562 livekit_room.clone(),
1563 request.user_id.to_string(),
1564 livekit_api::proto::ParticipantPermission {
1565 can_subscribe: true,
1566 can_publish,
1567 can_publish_data: can_publish,
1568 hidden: false,
1569 recorder: false,
1570 },
1571 )
1572 .await
1573 .trace_err();
1574 }
1575
1576 response.send(proto::Ack {})?;
1577 Ok(())
1578}
1579
1580/// Call someone else into the current room
1581async fn call(
1582 request: proto::Call,
1583 response: Response<proto::Call>,
1584 session: Session,
1585) -> Result<()> {
1586 let room_id = RoomId::from_proto(request.room_id);
1587 let calling_user_id = session.user_id();
1588 let calling_connection_id = session.connection_id;
1589 let called_user_id = UserId::from_proto(request.called_user_id);
1590 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1591 if !session
1592 .db()
1593 .await
1594 .has_contact(calling_user_id, called_user_id)
1595 .await?
1596 {
1597 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1598 }
1599
1600 let incoming_call = {
1601 let (room, incoming_call) = &mut *session
1602 .db()
1603 .await
1604 .call(
1605 room_id,
1606 calling_user_id,
1607 calling_connection_id,
1608 called_user_id,
1609 initial_project_id,
1610 )
1611 .await?;
1612 room_updated(room, &session.peer);
1613 mem::take(incoming_call)
1614 };
1615 update_user_contacts(called_user_id, &session).await?;
1616
1617 let mut calls = session
1618 .connection_pool()
1619 .await
1620 .user_connection_ids(called_user_id)
1621 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1622 .collect::<FuturesUnordered<_>>();
1623
1624 while let Some(call_response) = calls.next().await {
1625 match call_response.as_ref() {
1626 Ok(_) => {
1627 response.send(proto::Ack {})?;
1628 return Ok(());
1629 }
1630 Err(_) => {
1631 call_response.trace_err();
1632 }
1633 }
1634 }
1635
1636 {
1637 let room = session
1638 .db()
1639 .await
1640 .call_failed(room_id, called_user_id)
1641 .await?;
1642 room_updated(&room, &session.peer);
1643 }
1644 update_user_contacts(called_user_id, &session).await?;
1645
1646 Err(anyhow!("failed to ring user"))?
1647}
1648
1649/// Cancel an outgoing call.
1650async fn cancel_call(
1651 request: proto::CancelCall,
1652 response: Response<proto::CancelCall>,
1653 session: Session,
1654) -> Result<()> {
1655 let called_user_id = UserId::from_proto(request.called_user_id);
1656 let room_id = RoomId::from_proto(request.room_id);
1657 {
1658 let room = session
1659 .db()
1660 .await
1661 .cancel_call(room_id, session.connection_id, called_user_id)
1662 .await?;
1663 room_updated(&room, &session.peer);
1664 }
1665
1666 for connection_id in session
1667 .connection_pool()
1668 .await
1669 .user_connection_ids(called_user_id)
1670 {
1671 session
1672 .peer
1673 .send(
1674 connection_id,
1675 proto::CallCanceled {
1676 room_id: room_id.to_proto(),
1677 },
1678 )
1679 .trace_err();
1680 }
1681 response.send(proto::Ack {})?;
1682
1683 update_user_contacts(called_user_id, &session).await?;
1684 Ok(())
1685}
1686
1687/// Decline an incoming call.
1688async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1689 let room_id = RoomId::from_proto(message.room_id);
1690 {
1691 let room = session
1692 .db()
1693 .await
1694 .decline_call(Some(room_id), session.user_id())
1695 .await?
1696 .ok_or_else(|| anyhow!("failed to decline call"))?;
1697 room_updated(&room, &session.peer);
1698 }
1699
1700 for connection_id in session
1701 .connection_pool()
1702 .await
1703 .user_connection_ids(session.user_id())
1704 {
1705 session
1706 .peer
1707 .send(
1708 connection_id,
1709 proto::CallCanceled {
1710 room_id: room_id.to_proto(),
1711 },
1712 )
1713 .trace_err();
1714 }
1715 update_user_contacts(session.user_id(), &session).await?;
1716 Ok(())
1717}
1718
1719/// Updates other participants in the room with your current location.
1720async fn update_participant_location(
1721 request: proto::UpdateParticipantLocation,
1722 response: Response<proto::UpdateParticipantLocation>,
1723 session: Session,
1724) -> Result<()> {
1725 let room_id = RoomId::from_proto(request.room_id);
1726 let location = request
1727 .location
1728 .ok_or_else(|| anyhow!("invalid location"))?;
1729
1730 let db = session.db().await;
1731 let room = db
1732 .update_room_participant_location(room_id, session.connection_id, location)
1733 .await?;
1734
1735 room_updated(&room, &session.peer);
1736 response.send(proto::Ack {})?;
1737 Ok(())
1738}
1739
1740/// Share a project into the room.
1741async fn share_project(
1742 request: proto::ShareProject,
1743 response: Response<proto::ShareProject>,
1744 session: Session,
1745) -> Result<()> {
1746 let (project_id, room) = &*session
1747 .db()
1748 .await
1749 .share_project(
1750 RoomId::from_proto(request.room_id),
1751 session.connection_id,
1752 &request.worktrees,
1753 request.is_ssh_project,
1754 )
1755 .await?;
1756 response.send(proto::ShareProjectResponse {
1757 project_id: project_id.to_proto(),
1758 })?;
1759 room_updated(room, &session.peer);
1760
1761 Ok(())
1762}
1763
1764/// Unshare a project from the room.
1765async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1766 let project_id = ProjectId::from_proto(message.project_id);
1767 unshare_project_internal(project_id, session.connection_id, &session).await
1768}
1769
1770async fn unshare_project_internal(
1771 project_id: ProjectId,
1772 connection_id: ConnectionId,
1773 session: &Session,
1774) -> Result<()> {
1775 let delete = {
1776 let room_guard = session
1777 .db()
1778 .await
1779 .unshare_project(project_id, connection_id)
1780 .await?;
1781
1782 let (delete, room, guest_connection_ids) = &*room_guard;
1783
1784 let message = proto::UnshareProject {
1785 project_id: project_id.to_proto(),
1786 };
1787
1788 broadcast(
1789 Some(connection_id),
1790 guest_connection_ids.iter().copied(),
1791 |conn_id| session.peer.send(conn_id, message.clone()),
1792 );
1793 if let Some(room) = room {
1794 room_updated(room, &session.peer);
1795 }
1796
1797 *delete
1798 };
1799
1800 if delete {
1801 let db = session.db().await;
1802 db.delete_project(project_id).await?;
1803 }
1804
1805 Ok(())
1806}
1807
1808/// Join someone elses shared project.
1809async fn join_project(
1810 request: proto::JoinProject,
1811 response: Response<proto::JoinProject>,
1812 session: Session,
1813) -> Result<()> {
1814 let project_id = ProjectId::from_proto(request.project_id);
1815
1816 tracing::info!(%project_id, "join project");
1817
1818 let db = session.db().await;
1819 let (project, replica_id) = &mut *db
1820 .join_project(project_id, session.connection_id, session.user_id())
1821 .await?;
1822 drop(db);
1823 tracing::info!(%project_id, "join remote project");
1824 join_project_internal(response, session, project, replica_id)
1825}
1826
/// Abstraction over "something that can deliver a `JoinProjectResponse`", so
/// `join_project_internal` is not tied to a specific request type.
trait JoinProjectInternalResponse {
    /// Sends the join result, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Delegates to `Response<proto::JoinProject>`'s own `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The type-qualified path is meant to reach the inherent
        // `Response::send`, not this trait method (which would recurse).
        Response::<proto::JoinProject>::send(self, result)
    }
}
1835
1836fn join_project_internal(
1837 response: impl JoinProjectInternalResponse,
1838 session: Session,
1839 project: &mut Project,
1840 replica_id: &ReplicaId,
1841) -> Result<()> {
1842 let collaborators = project
1843 .collaborators
1844 .iter()
1845 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1846 .map(|collaborator| collaborator.to_proto())
1847 .collect::<Vec<_>>();
1848 let project_id = project.id;
1849 let guest_user_id = session.user_id();
1850
1851 let worktrees = project
1852 .worktrees
1853 .iter()
1854 .map(|(id, worktree)| proto::WorktreeMetadata {
1855 id: *id,
1856 root_name: worktree.root_name.clone(),
1857 visible: worktree.visible,
1858 abs_path: worktree.abs_path.clone(),
1859 })
1860 .collect::<Vec<_>>();
1861
1862 let add_project_collaborator = proto::AddProjectCollaborator {
1863 project_id: project_id.to_proto(),
1864 collaborator: Some(proto::Collaborator {
1865 peer_id: Some(session.connection_id.into()),
1866 replica_id: replica_id.0 as u32,
1867 user_id: guest_user_id.to_proto(),
1868 is_host: false,
1869 }),
1870 };
1871
1872 for collaborator in &collaborators {
1873 session
1874 .peer
1875 .send(
1876 collaborator.peer_id.unwrap().into(),
1877 add_project_collaborator.clone(),
1878 )
1879 .trace_err();
1880 }
1881
1882 // First, we send the metadata associated with each worktree.
1883 response.send(proto::JoinProjectResponse {
1884 project_id: project.id.0 as u64,
1885 worktrees: worktrees.clone(),
1886 replica_id: replica_id.0 as u32,
1887 collaborators: collaborators.clone(),
1888 language_servers: project.language_servers.clone(),
1889 role: project.role.into(),
1890 })?;
1891
1892 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1893 // Stream this worktree's entries.
1894 let message = proto::UpdateWorktree {
1895 project_id: project_id.to_proto(),
1896 worktree_id,
1897 abs_path: worktree.abs_path.clone(),
1898 root_name: worktree.root_name,
1899 updated_entries: worktree.entries,
1900 removed_entries: Default::default(),
1901 scan_id: worktree.scan_id,
1902 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1903 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1904 removed_repositories: Default::default(),
1905 };
1906 for update in proto::split_worktree_update(message) {
1907 session.peer.send(session.connection_id, update.clone())?;
1908 }
1909
1910 // Stream this worktree's diagnostics.
1911 for summary in worktree.diagnostic_summaries {
1912 session.peer.send(
1913 session.connection_id,
1914 proto::UpdateDiagnosticSummary {
1915 project_id: project_id.to_proto(),
1916 worktree_id: worktree.id,
1917 summary: Some(summary),
1918 },
1919 )?;
1920 }
1921
1922 for settings_file in worktree.settings_files {
1923 session.peer.send(
1924 session.connection_id,
1925 proto::UpdateWorktreeSettings {
1926 project_id: project_id.to_proto(),
1927 worktree_id: worktree.id,
1928 path: settings_file.path,
1929 content: Some(settings_file.content),
1930 kind: Some(settings_file.kind.to_proto() as i32),
1931 },
1932 )?;
1933 }
1934 }
1935
1936 for repository in mem::take(&mut project.repositories) {
1937 for update in split_repository_update(repository) {
1938 session.peer.send(session.connection_id, update)?;
1939 }
1940 }
1941
1942 for language_server in &project.language_servers {
1943 session.peer.send(
1944 session.connection_id,
1945 proto::UpdateLanguageServer {
1946 project_id: project_id.to_proto(),
1947 language_server_id: language_server.id,
1948 variant: Some(
1949 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1950 proto::LspDiskBasedDiagnosticsUpdated {},
1951 ),
1952 ),
1953 },
1954 )?;
1955 }
1956
1957 Ok(())
1958}
1959
1960/// Leave someone elses shared project.
1961async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1962 let sender_id = session.connection_id;
1963 let project_id = ProjectId::from_proto(request.project_id);
1964 let db = session.db().await;
1965
1966 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1967 tracing::info!(
1968 %project_id,
1969 "leave project"
1970 );
1971
1972 project_left(project, &session);
1973 if let Some(room) = room {
1974 room_updated(room, &session.peer);
1975 }
1976
1977 Ok(())
1978}
1979
1980/// Updates other participants with changes to the project
1981async fn update_project(
1982 request: proto::UpdateProject,
1983 response: Response<proto::UpdateProject>,
1984 session: Session,
1985) -> Result<()> {
1986 let project_id = ProjectId::from_proto(request.project_id);
1987 let (room, guest_connection_ids) = &*session
1988 .db()
1989 .await
1990 .update_project(project_id, session.connection_id, &request.worktrees)
1991 .await?;
1992 broadcast(
1993 Some(session.connection_id),
1994 guest_connection_ids.iter().copied(),
1995 |connection_id| {
1996 session
1997 .peer
1998 .forward_send(session.connection_id, connection_id, request.clone())
1999 },
2000 );
2001 if let Some(room) = room {
2002 room_updated(room, &session.peer);
2003 }
2004 response.send(proto::Ack {})?;
2005
2006 Ok(())
2007}
2008
2009/// Updates other participants with changes to the worktree
2010async fn update_worktree(
2011 request: proto::UpdateWorktree,
2012 response: Response<proto::UpdateWorktree>,
2013 session: Session,
2014) -> Result<()> {
2015 let guest_connection_ids = session
2016 .db()
2017 .await
2018 .update_worktree(&request, session.connection_id)
2019 .await?;
2020
2021 broadcast(
2022 Some(session.connection_id),
2023 guest_connection_ids.iter().copied(),
2024 |connection_id| {
2025 session
2026 .peer
2027 .forward_send(session.connection_id, connection_id, request.clone())
2028 },
2029 );
2030 response.send(proto::Ack {})?;
2031 Ok(())
2032}
2033
2034async fn update_repository(
2035 request: proto::UpdateRepository,
2036 response: Response<proto::UpdateRepository>,
2037 session: Session,
2038) -> Result<()> {
2039 let guest_connection_ids = session
2040 .db()
2041 .await
2042 .update_repository(&request, session.connection_id)
2043 .await?;
2044
2045 broadcast(
2046 Some(session.connection_id),
2047 guest_connection_ids.iter().copied(),
2048 |connection_id| {
2049 session
2050 .peer
2051 .forward_send(session.connection_id, connection_id, request.clone())
2052 },
2053 );
2054 response.send(proto::Ack {})?;
2055 Ok(())
2056}
2057
2058async fn remove_repository(
2059 request: proto::RemoveRepository,
2060 response: Response<proto::RemoveRepository>,
2061 session: Session,
2062) -> Result<()> {
2063 let guest_connection_ids = session
2064 .db()
2065 .await
2066 .remove_repository(&request, session.connection_id)
2067 .await?;
2068
2069 broadcast(
2070 Some(session.connection_id),
2071 guest_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, request.clone())
2076 },
2077 );
2078 response.send(proto::Ack {})?;
2079 Ok(())
2080}
2081
2082/// Updates other participants with changes to the diagnostics
2083async fn update_diagnostic_summary(
2084 message: proto::UpdateDiagnosticSummary,
2085 session: Session,
2086) -> Result<()> {
2087 let guest_connection_ids = session
2088 .db()
2089 .await
2090 .update_diagnostic_summary(&message, session.connection_id)
2091 .await?;
2092
2093 broadcast(
2094 Some(session.connection_id),
2095 guest_connection_ids.iter().copied(),
2096 |connection_id| {
2097 session
2098 .peer
2099 .forward_send(session.connection_id, connection_id, message.clone())
2100 },
2101 );
2102
2103 Ok(())
2104}
2105
2106/// Updates other participants with changes to the worktree settings
2107async fn update_worktree_settings(
2108 message: proto::UpdateWorktreeSettings,
2109 session: Session,
2110) -> Result<()> {
2111 let guest_connection_ids = session
2112 .db()
2113 .await
2114 .update_worktree_settings(&message, session.connection_id)
2115 .await?;
2116
2117 broadcast(
2118 Some(session.connection_id),
2119 guest_connection_ids.iter().copied(),
2120 |connection_id| {
2121 session
2122 .peer
2123 .forward_send(session.connection_id, connection_id, message.clone())
2124 },
2125 );
2126
2127 Ok(())
2128}
2129
2130/// Notify other participants that a language server has started.
2131async fn start_language_server(
2132 request: proto::StartLanguageServer,
2133 session: Session,
2134) -> Result<()> {
2135 let guest_connection_ids = session
2136 .db()
2137 .await
2138 .start_language_server(&request, session.connection_id)
2139 .await?;
2140
2141 broadcast(
2142 Some(session.connection_id),
2143 guest_connection_ids.iter().copied(),
2144 |connection_id| {
2145 session
2146 .peer
2147 .forward_send(session.connection_id, connection_id, request.clone())
2148 },
2149 );
2150 Ok(())
2151}
2152
2153/// Notify other participants that a language server has changed.
2154async fn update_language_server(
2155 request: proto::UpdateLanguageServer,
2156 session: Session,
2157) -> Result<()> {
2158 let project_id = ProjectId::from_proto(request.project_id);
2159 let project_connection_ids = session
2160 .db()
2161 .await
2162 .project_connection_ids(project_id, session.connection_id, true)
2163 .await?;
2164 broadcast(
2165 Some(session.connection_id),
2166 project_connection_ids.iter().copied(),
2167 |connection_id| {
2168 session
2169 .peer
2170 .forward_send(session.connection_id, connection_id, request.clone())
2171 },
2172 );
2173 Ok(())
2174}
2175
2176/// forward a project request to the host. These requests should be read only
2177/// as guests are allowed to send them.
2178async fn forward_read_only_project_request<T>(
2179 request: T,
2180 response: Response<T>,
2181 session: Session,
2182) -> Result<()>
2183where
2184 T: EntityMessage + RequestMessage,
2185{
2186 let project_id = ProjectId::from_proto(request.remote_entity_id());
2187 let host_connection_id = session
2188 .db()
2189 .await
2190 .host_for_read_only_project_request(project_id, session.connection_id)
2191 .await?;
2192 let payload = session
2193 .peer
2194 .forward_request(session.connection_id, host_connection_id, request)
2195 .await?;
2196 response.send(payload)?;
2197 Ok(())
2198}
2199
2200async fn forward_find_search_candidates_request(
2201 request: proto::FindSearchCandidates,
2202 response: Response<proto::FindSearchCandidates>,
2203 session: Session,
2204) -> Result<()> {
2205 let project_id = ProjectId::from_proto(request.remote_entity_id());
2206 let host_connection_id = session
2207 .db()
2208 .await
2209 .host_for_read_only_project_request(project_id, session.connection_id)
2210 .await?;
2211 let payload = session
2212 .peer
2213 .forward_request(session.connection_id, host_connection_id, request)
2214 .await?;
2215 response.send(payload)?;
2216 Ok(())
2217}
2218
2219/// forward a project request to the host. These requests are disallowed
2220/// for guests.
2221async fn forward_mutating_project_request<T>(
2222 request: T,
2223 response: Response<T>,
2224 session: Session,
2225) -> Result<()>
2226where
2227 T: EntityMessage + RequestMessage,
2228{
2229 let project_id = ProjectId::from_proto(request.remote_entity_id());
2230
2231 let host_connection_id = session
2232 .db()
2233 .await
2234 .host_for_mutating_project_request(project_id, session.connection_id)
2235 .await?;
2236 let payload = session
2237 .peer
2238 .forward_request(session.connection_id, host_connection_id, request)
2239 .await?;
2240 response.send(payload)?;
2241 Ok(())
2242}
2243
2244/// Notify other participants that a new buffer has been created
2245async fn create_buffer_for_peer(
2246 request: proto::CreateBufferForPeer,
2247 session: Session,
2248) -> Result<()> {
2249 session
2250 .db()
2251 .await
2252 .check_user_is_project_host(
2253 ProjectId::from_proto(request.project_id),
2254 session.connection_id,
2255 )
2256 .await?;
2257 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2258 session
2259 .peer
2260 .forward_send(session.connection_id, peer_id.into(), request)?;
2261 Ok(())
2262}
2263
2264/// Notify other participants that a buffer has been updated. This is
2265/// allowed for guests as long as the update is limited to selections.
2266async fn update_buffer(
2267 request: proto::UpdateBuffer,
2268 response: Response<proto::UpdateBuffer>,
2269 session: Session,
2270) -> Result<()> {
2271 let project_id = ProjectId::from_proto(request.project_id);
2272 let mut capability = Capability::ReadOnly;
2273
2274 for op in request.operations.iter() {
2275 match op.variant {
2276 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2277 Some(_) => capability = Capability::ReadWrite,
2278 }
2279 }
2280
2281 let host = {
2282 let guard = session
2283 .db()
2284 .await
2285 .connections_for_buffer_update(project_id, session.connection_id, capability)
2286 .await?;
2287
2288 let (host, guests) = &*guard;
2289
2290 broadcast(
2291 Some(session.connection_id),
2292 guests.clone(),
2293 |connection_id| {
2294 session
2295 .peer
2296 .forward_send(session.connection_id, connection_id, request.clone())
2297 },
2298 );
2299
2300 *host
2301 };
2302
2303 if host != session.connection_id {
2304 session
2305 .peer
2306 .forward_request(session.connection_id, host, request.clone())
2307 .await?;
2308 }
2309
2310 response.send(proto::Ack {})?;
2311 Ok(())
2312}
2313
2314async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2315 let project_id = ProjectId::from_proto(message.project_id);
2316
2317 let operation = message.operation.as_ref().context("invalid operation")?;
2318 let capability = match operation.variant.as_ref() {
2319 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2320 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2321 match buffer_op.variant {
2322 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2323 Capability::ReadOnly
2324 }
2325 _ => Capability::ReadWrite,
2326 }
2327 } else {
2328 Capability::ReadWrite
2329 }
2330 }
2331 Some(_) => Capability::ReadWrite,
2332 None => Capability::ReadOnly,
2333 };
2334
2335 let guard = session
2336 .db()
2337 .await
2338 .connections_for_buffer_update(project_id, session.connection_id, capability)
2339 .await?;
2340
2341 let (host, guests) = &*guard;
2342
2343 broadcast(
2344 Some(session.connection_id),
2345 guests.iter().chain([host]).copied(),
2346 |connection_id| {
2347 session
2348 .peer
2349 .forward_send(session.connection_id, connection_id, message.clone())
2350 },
2351 );
2352
2353 Ok(())
2354}
2355
2356/// Notify other participants that a project has been updated.
2357async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2358 request: T,
2359 session: Session,
2360) -> Result<()> {
2361 let project_id = ProjectId::from_proto(request.remote_entity_id());
2362 let project_connection_ids = session
2363 .db()
2364 .await
2365 .project_connection_ids(project_id, session.connection_id, false)
2366 .await?;
2367
2368 broadcast(
2369 Some(session.connection_id),
2370 project_connection_ids.iter().copied(),
2371 |connection_id| {
2372 session
2373 .peer
2374 .forward_send(session.connection_id, connection_id, request.clone())
2375 },
2376 );
2377 Ok(())
2378}
2379
2380/// Start following another user in a call.
2381async fn follow(
2382 request: proto::Follow,
2383 response: Response<proto::Follow>,
2384 session: Session,
2385) -> Result<()> {
2386 let room_id = RoomId::from_proto(request.room_id);
2387 let project_id = request.project_id.map(ProjectId::from_proto);
2388 let leader_id = request
2389 .leader_id
2390 .ok_or_else(|| anyhow!("invalid leader id"))?
2391 .into();
2392 let follower_id = session.connection_id;
2393
2394 session
2395 .db()
2396 .await
2397 .check_room_participants(room_id, leader_id, session.connection_id)
2398 .await?;
2399
2400 let response_payload = session
2401 .peer
2402 .forward_request(session.connection_id, leader_id, request)
2403 .await?;
2404 response.send(response_payload)?;
2405
2406 if let Some(project_id) = project_id {
2407 let room = session
2408 .db()
2409 .await
2410 .follow(room_id, project_id, leader_id, follower_id)
2411 .await?;
2412 room_updated(&room, &session.peer);
2413 }
2414
2415 Ok(())
2416}
2417
2418/// Stop following another user in a call.
2419async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2420 let room_id = RoomId::from_proto(request.room_id);
2421 let project_id = request.project_id.map(ProjectId::from_proto);
2422 let leader_id = request
2423 .leader_id
2424 .ok_or_else(|| anyhow!("invalid leader id"))?
2425 .into();
2426 let follower_id = session.connection_id;
2427
2428 session
2429 .db()
2430 .await
2431 .check_room_participants(room_id, leader_id, session.connection_id)
2432 .await?;
2433
2434 session
2435 .peer
2436 .forward_send(session.connection_id, leader_id, request)?;
2437
2438 if let Some(project_id) = project_id {
2439 let room = session
2440 .db()
2441 .await
2442 .unfollow(room_id, project_id, leader_id, follower_id)
2443 .await?;
2444 room_updated(&room, &session.peer);
2445 }
2446
2447 Ok(())
2448}
2449
2450/// Notify everyone following you of your current location.
2451async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2452 let room_id = RoomId::from_proto(request.room_id);
2453 let database = session.db.lock().await;
2454
2455 let connection_ids = if let Some(project_id) = request.project_id {
2456 let project_id = ProjectId::from_proto(project_id);
2457 database
2458 .project_connection_ids(project_id, session.connection_id, true)
2459 .await?
2460 } else {
2461 database
2462 .room_connection_ids(room_id, session.connection_id)
2463 .await?
2464 };
2465
2466 // For now, don't send view update messages back to that view's current leader.
2467 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2468 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2469 _ => None,
2470 });
2471
2472 for connection_id in connection_ids.iter().cloned() {
2473 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2474 session
2475 .peer
2476 .forward_send(session.connection_id, connection_id, request.clone())?;
2477 }
2478 }
2479 Ok(())
2480}
2481
2482/// Get public data about users.
2483async fn get_users(
2484 request: proto::GetUsers,
2485 response: Response<proto::GetUsers>,
2486 session: Session,
2487) -> Result<()> {
2488 let user_ids = request
2489 .user_ids
2490 .into_iter()
2491 .map(UserId::from_proto)
2492 .collect();
2493 let users = session
2494 .db()
2495 .await
2496 .get_users_by_ids(user_ids)
2497 .await?
2498 .into_iter()
2499 .map(|user| proto::User {
2500 id: user.id.to_proto(),
2501 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2502 github_login: user.github_login,
2503 email: user.email_address,
2504 name: user.name,
2505 })
2506 .collect();
2507 response.send(proto::UsersResponse { users })?;
2508 Ok(())
2509}
2510
2511/// Search for users (to invite) buy Github login
2512async fn fuzzy_search_users(
2513 request: proto::FuzzySearchUsers,
2514 response: Response<proto::FuzzySearchUsers>,
2515 session: Session,
2516) -> Result<()> {
2517 let query = request.query;
2518 let users = match query.len() {
2519 0 => vec![],
2520 1 | 2 => session
2521 .db()
2522 .await
2523 .get_user_by_github_login(&query)
2524 .await?
2525 .into_iter()
2526 .collect(),
2527 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2528 };
2529 let users = users
2530 .into_iter()
2531 .filter(|user| user.id != session.user_id())
2532 .map(|user| proto::User {
2533 id: user.id.to_proto(),
2534 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2535 github_login: user.github_login,
2536 name: user.name,
2537 email: user.email_address,
2538 })
2539 .collect();
2540 response.send(proto::UsersResponse { users })?;
2541 Ok(())
2542}
2543
2544/// Send a contact request to another user.
2545async fn request_contact(
2546 request: proto::RequestContact,
2547 response: Response<proto::RequestContact>,
2548 session: Session,
2549) -> Result<()> {
2550 let requester_id = session.user_id();
2551 let responder_id = UserId::from_proto(request.responder_id);
2552 if requester_id == responder_id {
2553 return Err(anyhow!("cannot add yourself as a contact"))?;
2554 }
2555
2556 let notifications = session
2557 .db()
2558 .await
2559 .send_contact_request(requester_id, responder_id)
2560 .await?;
2561
2562 // Update outgoing contact requests of requester
2563 let mut update = proto::UpdateContacts::default();
2564 update.outgoing_requests.push(responder_id.to_proto());
2565 for connection_id in session
2566 .connection_pool()
2567 .await
2568 .user_connection_ids(requester_id)
2569 {
2570 session.peer.send(connection_id, update.clone())?;
2571 }
2572
2573 // Update incoming contact requests of responder
2574 let mut update = proto::UpdateContacts::default();
2575 update
2576 .incoming_requests
2577 .push(proto::IncomingContactRequest {
2578 requester_id: requester_id.to_proto(),
2579 });
2580 let connection_pool = session.connection_pool().await;
2581 for connection_id in connection_pool.user_connection_ids(responder_id) {
2582 session.peer.send(connection_id, update.clone())?;
2583 }
2584
2585 send_notifications(&connection_pool, &session.peer, notifications);
2586
2587 response.send(proto::Ack {})?;
2588 Ok(())
2589}
2590
2591/// Accept or decline a contact request
2592async fn respond_to_contact_request(
2593 request: proto::RespondToContactRequest,
2594 response: Response<proto::RespondToContactRequest>,
2595 session: Session,
2596) -> Result<()> {
2597 let responder_id = session.user_id();
2598 let requester_id = UserId::from_proto(request.requester_id);
2599 let db = session.db().await;
2600 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2601 db.dismiss_contact_notification(responder_id, requester_id)
2602 .await?;
2603 } else {
2604 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2605
2606 let notifications = db
2607 .respond_to_contact_request(responder_id, requester_id, accept)
2608 .await?;
2609 let requester_busy = db.is_user_busy(requester_id).await?;
2610 let responder_busy = db.is_user_busy(responder_id).await?;
2611
2612 let pool = session.connection_pool().await;
2613 // Update responder with new contact
2614 let mut update = proto::UpdateContacts::default();
2615 if accept {
2616 update
2617 .contacts
2618 .push(contact_for_user(requester_id, requester_busy, &pool));
2619 }
2620 update
2621 .remove_incoming_requests
2622 .push(requester_id.to_proto());
2623 for connection_id in pool.user_connection_ids(responder_id) {
2624 session.peer.send(connection_id, update.clone())?;
2625 }
2626
2627 // Update requester with new contact
2628 let mut update = proto::UpdateContacts::default();
2629 if accept {
2630 update
2631 .contacts
2632 .push(contact_for_user(responder_id, responder_busy, &pool));
2633 }
2634 update
2635 .remove_outgoing_requests
2636 .push(responder_id.to_proto());
2637
2638 for connection_id in pool.user_connection_ids(requester_id) {
2639 session.peer.send(connection_id, update.clone())?;
2640 }
2641
2642 send_notifications(&pool, &session.peer, notifications);
2643 }
2644
2645 response.send(proto::Ack {})?;
2646 Ok(())
2647}
2648
2649/// Remove a contact.
2650async fn remove_contact(
2651 request: proto::RemoveContact,
2652 response: Response<proto::RemoveContact>,
2653 session: Session,
2654) -> Result<()> {
2655 let requester_id = session.user_id();
2656 let responder_id = UserId::from_proto(request.user_id);
2657 let db = session.db().await;
2658 let (contact_accepted, deleted_notification_id) =
2659 db.remove_contact(requester_id, responder_id).await?;
2660
2661 let pool = session.connection_pool().await;
2662 // Update outgoing contact requests of requester
2663 let mut update = proto::UpdateContacts::default();
2664 if contact_accepted {
2665 update.remove_contacts.push(responder_id.to_proto());
2666 } else {
2667 update
2668 .remove_outgoing_requests
2669 .push(responder_id.to_proto());
2670 }
2671 for connection_id in pool.user_connection_ids(requester_id) {
2672 session.peer.send(connection_id, update.clone())?;
2673 }
2674
2675 // Update incoming contact requests of responder
2676 let mut update = proto::UpdateContacts::default();
2677 if contact_accepted {
2678 update.remove_contacts.push(requester_id.to_proto());
2679 } else {
2680 update
2681 .remove_incoming_requests
2682 .push(requester_id.to_proto());
2683 }
2684 for connection_id in pool.user_connection_ids(responder_id) {
2685 session.peer.send(connection_id, update.clone())?;
2686 if let Some(notification_id) = deleted_notification_id {
2687 session.peer.send(
2688 connection_id,
2689 proto::DeleteNotification {
2690 notification_id: notification_id.to_proto(),
2691 },
2692 )?;
2693 }
2694 }
2695
2696 response.send(proto::Ack {})?;
2697 Ok(())
2698}
2699
2700fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2701 version.0.minor() < 139
2702}
2703
/// Sends an `UpdateUserPlan` message to this session's connection.
///
/// The message describes the user's plan, the billing customer's trial start
/// (as a Unix timestamp), whether usage-based billing is enabled, the current
/// subscription period, and the usage and limits within that period.
///
/// Database errors are propagated. A failure to send the final message is
/// only logged via `trace_err`.
async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let feature_flags = db.get_user_flags(user_id).await?;
    let plan = session.current_plan(&db).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user_id).await?;
    let billing_preferences = db.get_billing_preferences(user_id).await?;

    // The subscription period and usage are only computed when an LLM database is configured.
    let (subscription_period, usage) = if let Some(llm_db) = session.app_state.llm_db.clone() {
        let subscription = db.get_active_billing_subscription(user_id).await?;

        let subscription_period = crate::db::billing_subscription::Model::current_period(
            subscription,
            session.is_staff(),
        );

        // Usage is only looked up when there is a current period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user_id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    session
        .peer
        .send(
            session.connection_id,
            proto::UpdateUserPlan {
                plan: plan.into(),
                trial_started_at: billing_customer
                    .and_then(|billing_customer| billing_customer.trial_started_at)
                    .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
                // Staff always report usage-based billing as enabled; everyone else
                // reports their stored overage preference, if any.
                is_usage_based_billing_enabled: if session.is_staff() {
                    Some(true)
                } else {
                    billing_preferences
                        .map(|preferences| preferences.model_request_overages_enabled)
                },
                subscription_period: subscription_period.map(|(started_at, ended_at)| {
                    proto::SubscriptionPeriod {
                        started_at: started_at.timestamp() as u64,
                        ended_at: ended_at.timestamp() as u64,
                    }
                }),
                usage: usage.map(|usage| {
                    // Translate the proto plan into the LLM client's plan type so its
                    // limit helpers can be used.
                    let plan = match plan {
                        proto::Plan::Free => zed_llm_client::Plan::Free,
                        proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                        proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
                    };

                    let model_requests_limit = match plan.model_requests_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            // Trial users with the extended-trial feature flag get a
                            // fixed limit of 1,000 model requests instead of the plan default.
                            let limit = if plan == zed_llm_client::Plan::ZedProTrial
                                && feature_flags
                                    .iter()
                                    .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                            {
                                1_000
                            } else {
                                limit
                            };

                            zed_llm_client::UsageLimit::Limited(limit)
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            zed_llm_client::UsageLimit::Unlimited
                        }
                    };

                    proto::SubscriptionUsage {
                        model_requests_usage_amount: usage.model_requests as u32,
                        model_requests_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match model_requests_limit {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                        edit_predictions_usage_amount: usage.edit_predictions as u32,
                        // Edit-prediction limits come straight from the plan (no overrides).
                        edit_predictions_usage_limit: Some(proto::UsageLimit {
                            variant: Some(match plan.edit_predictions_limit() {
                                zed_llm_client::UsageLimit::Limited(limit) => {
                                    proto::usage_limit::Variant::Limited(
                                        proto::usage_limit::Limited {
                                            limit: limit as u32,
                                        },
                                    )
                                }
                                zed_llm_client::UsageLimit::Unlimited => {
                                    proto::usage_limit::Variant::Unlimited(
                                        proto::usage_limit::Unlimited {},
                                    )
                                }
                            }),
                        }),
                    }
                }),
            },
        )
        .trace_err();

    Ok(())
}
2823
2824async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2825 subscribe_user_to_channels(session.user_id(), &session).await?;
2826 Ok(())
2827}
2828
2829async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2830 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2831 let mut pool = session.connection_pool().await;
2832 for membership in &channels_for_user.channel_memberships {
2833 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2834 }
2835 session.peer.send(
2836 session.connection_id,
2837 build_update_user_channels(&channels_for_user),
2838 )?;
2839 session.peer.send(
2840 session.connection_id,
2841 build_channels_update(channels_for_user),
2842 )?;
2843 Ok(())
2844}
2845
2846/// Creates a new channel.
2847async fn create_channel(
2848 request: proto::CreateChannel,
2849 response: Response<proto::CreateChannel>,
2850 session: Session,
2851) -> Result<()> {
2852 let db = session.db().await;
2853
2854 let parent_id = request.parent_id.map(ChannelId::from_proto);
2855 let (channel, membership) = db
2856 .create_channel(&request.name, parent_id, session.user_id())
2857 .await?;
2858
2859 let root_id = channel.root_id();
2860 let channel = Channel::from_model(channel);
2861
2862 response.send(proto::CreateChannelResponse {
2863 channel: Some(channel.to_proto()),
2864 parent_id: request.parent_id,
2865 })?;
2866
2867 let mut connection_pool = session.connection_pool().await;
2868 if let Some(membership) = membership {
2869 connection_pool.subscribe_to_channel(
2870 membership.user_id,
2871 membership.channel_id,
2872 membership.role,
2873 );
2874 let update = proto::UpdateUserChannels {
2875 channel_memberships: vec![proto::ChannelMembership {
2876 channel_id: membership.channel_id.to_proto(),
2877 role: membership.role.into(),
2878 }],
2879 ..Default::default()
2880 };
2881 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2882 session.peer.send(connection_id, update.clone())?;
2883 }
2884 }
2885
2886 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2887 if !role.can_see_channel(channel.visibility) {
2888 continue;
2889 }
2890
2891 let update = proto::UpdateChannels {
2892 channels: vec![channel.to_proto()],
2893 ..Default::default()
2894 };
2895 session.peer.send(connection_id, update.clone())?;
2896 }
2897
2898 Ok(())
2899}
2900
2901/// Delete a channel
2902async fn delete_channel(
2903 request: proto::DeleteChannel,
2904 response: Response<proto::DeleteChannel>,
2905 session: Session,
2906) -> Result<()> {
2907 let db = session.db().await;
2908
2909 let channel_id = request.channel_id;
2910 let (root_channel, removed_channels) = db
2911 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2912 .await?;
2913 response.send(proto::Ack {})?;
2914
2915 // Notify members of removed channels
2916 let mut update = proto::UpdateChannels::default();
2917 update
2918 .delete_channels
2919 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2920
2921 let connection_pool = session.connection_pool().await;
2922 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2923 session.peer.send(connection_id, update.clone())?;
2924 }
2925
2926 Ok(())
2927}
2928
2929/// Invite someone to join a channel.
2930async fn invite_channel_member(
2931 request: proto::InviteChannelMember,
2932 response: Response<proto::InviteChannelMember>,
2933 session: Session,
2934) -> Result<()> {
2935 let db = session.db().await;
2936 let channel_id = ChannelId::from_proto(request.channel_id);
2937 let invitee_id = UserId::from_proto(request.user_id);
2938 let InviteMemberResult {
2939 channel,
2940 notifications,
2941 } = db
2942 .invite_channel_member(
2943 channel_id,
2944 invitee_id,
2945 session.user_id(),
2946 request.role().into(),
2947 )
2948 .await?;
2949
2950 let update = proto::UpdateChannels {
2951 channel_invitations: vec![channel.to_proto()],
2952 ..Default::default()
2953 };
2954
2955 let connection_pool = session.connection_pool().await;
2956 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2957 session.peer.send(connection_id, update.clone())?;
2958 }
2959
2960 send_notifications(&connection_pool, &session.peer, notifications);
2961
2962 response.send(proto::Ack {})?;
2963 Ok(())
2964}
2965
2966/// remove someone from a channel
2967async fn remove_channel_member(
2968 request: proto::RemoveChannelMember,
2969 response: Response<proto::RemoveChannelMember>,
2970 session: Session,
2971) -> Result<()> {
2972 let db = session.db().await;
2973 let channel_id = ChannelId::from_proto(request.channel_id);
2974 let member_id = UserId::from_proto(request.user_id);
2975
2976 let RemoveChannelMemberResult {
2977 membership_update,
2978 notification_id,
2979 } = db
2980 .remove_channel_member(channel_id, member_id, session.user_id())
2981 .await?;
2982
2983 let mut connection_pool = session.connection_pool().await;
2984 notify_membership_updated(
2985 &mut connection_pool,
2986 membership_update,
2987 member_id,
2988 &session.peer,
2989 );
2990 for connection_id in connection_pool.user_connection_ids(member_id) {
2991 if let Some(notification_id) = notification_id {
2992 session
2993 .peer
2994 .send(
2995 connection_id,
2996 proto::DeleteNotification {
2997 notification_id: notification_id.to_proto(),
2998 },
2999 )
3000 .trace_err();
3001 }
3002 }
3003
3004 response.send(proto::Ack {})?;
3005 Ok(())
3006}
3007
3008/// Toggle the channel between public and private.
3009/// Care is taken to maintain the invariant that public channels only descend from public channels,
3010/// (though members-only channels can appear at any point in the hierarchy).
3011async fn set_channel_visibility(
3012 request: proto::SetChannelVisibility,
3013 response: Response<proto::SetChannelVisibility>,
3014 session: Session,
3015) -> Result<()> {
3016 let db = session.db().await;
3017 let channel_id = ChannelId::from_proto(request.channel_id);
3018 let visibility = request.visibility().into();
3019
3020 let channel_model = db
3021 .set_channel_visibility(channel_id, visibility, session.user_id())
3022 .await?;
3023 let root_id = channel_model.root_id();
3024 let channel = Channel::from_model(channel_model);
3025
3026 let mut connection_pool = session.connection_pool().await;
3027 for (user_id, role) in connection_pool
3028 .channel_user_ids(root_id)
3029 .collect::<Vec<_>>()
3030 .into_iter()
3031 {
3032 let update = if role.can_see_channel(channel.visibility) {
3033 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3034 proto::UpdateChannels {
3035 channels: vec![channel.to_proto()],
3036 ..Default::default()
3037 }
3038 } else {
3039 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3040 proto::UpdateChannels {
3041 delete_channels: vec![channel.id.to_proto()],
3042 ..Default::default()
3043 }
3044 };
3045
3046 for connection_id in connection_pool.user_connection_ids(user_id) {
3047 session.peer.send(connection_id, update.clone())?;
3048 }
3049 }
3050
3051 response.send(proto::Ack {})?;
3052 Ok(())
3053}
3054
3055/// Alter the role for a user in the channel.
3056async fn set_channel_member_role(
3057 request: proto::SetChannelMemberRole,
3058 response: Response<proto::SetChannelMemberRole>,
3059 session: Session,
3060) -> Result<()> {
3061 let db = session.db().await;
3062 let channel_id = ChannelId::from_proto(request.channel_id);
3063 let member_id = UserId::from_proto(request.user_id);
3064 let result = db
3065 .set_channel_member_role(
3066 channel_id,
3067 session.user_id(),
3068 member_id,
3069 request.role().into(),
3070 )
3071 .await?;
3072
3073 match result {
3074 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3075 let mut connection_pool = session.connection_pool().await;
3076 notify_membership_updated(
3077 &mut connection_pool,
3078 membership_update,
3079 member_id,
3080 &session.peer,
3081 )
3082 }
3083 db::SetMemberRoleResult::InviteUpdated(channel) => {
3084 let update = proto::UpdateChannels {
3085 channel_invitations: vec![channel.to_proto()],
3086 ..Default::default()
3087 };
3088
3089 for connection_id in session
3090 .connection_pool()
3091 .await
3092 .user_connection_ids(member_id)
3093 {
3094 session.peer.send(connection_id, update.clone())?;
3095 }
3096 }
3097 }
3098
3099 response.send(proto::Ack {})?;
3100 Ok(())
3101}
3102
3103/// Change the name of a channel
3104async fn rename_channel(
3105 request: proto::RenameChannel,
3106 response: Response<proto::RenameChannel>,
3107 session: Session,
3108) -> Result<()> {
3109 let db = session.db().await;
3110 let channel_id = ChannelId::from_proto(request.channel_id);
3111 let channel_model = db
3112 .rename_channel(channel_id, session.user_id(), &request.name)
3113 .await?;
3114 let root_id = channel_model.root_id();
3115 let channel = Channel::from_model(channel_model);
3116
3117 response.send(proto::RenameChannelResponse {
3118 channel: Some(channel.to_proto()),
3119 })?;
3120
3121 let connection_pool = session.connection_pool().await;
3122 let update = proto::UpdateChannels {
3123 channels: vec![channel.to_proto()],
3124 ..Default::default()
3125 };
3126 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3127 if role.can_see_channel(channel.visibility) {
3128 session.peer.send(connection_id, update.clone())?;
3129 }
3130 }
3131
3132 Ok(())
3133}
3134
3135/// Move a channel to a new parent.
3136async fn move_channel(
3137 request: proto::MoveChannel,
3138 response: Response<proto::MoveChannel>,
3139 session: Session,
3140) -> Result<()> {
3141 let channel_id = ChannelId::from_proto(request.channel_id);
3142 let to = ChannelId::from_proto(request.to);
3143
3144 let (root_id, channels) = session
3145 .db()
3146 .await
3147 .move_channel(channel_id, to, session.user_id())
3148 .await?;
3149
3150 let connection_pool = session.connection_pool().await;
3151 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3152 let channels = channels
3153 .iter()
3154 .filter_map(|channel| {
3155 if role.can_see_channel(channel.visibility) {
3156 Some(channel.to_proto())
3157 } else {
3158 None
3159 }
3160 })
3161 .collect::<Vec<_>>();
3162 if channels.is_empty() {
3163 continue;
3164 }
3165
3166 let update = proto::UpdateChannels {
3167 channels,
3168 ..Default::default()
3169 };
3170
3171 session.peer.send(connection_id, update.clone())?;
3172 }
3173
3174 response.send(Ack {})?;
3175 Ok(())
3176}
3177
3178/// Get the list of channel members
3179async fn get_channel_members(
3180 request: proto::GetChannelMembers,
3181 response: Response<proto::GetChannelMembers>,
3182 session: Session,
3183) -> Result<()> {
3184 let db = session.db().await;
3185 let channel_id = ChannelId::from_proto(request.channel_id);
3186 let limit = if request.limit == 0 {
3187 u16::MAX as u64
3188 } else {
3189 request.limit
3190 };
3191 let (members, users) = db
3192 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3193 .await?;
3194 response.send(proto::GetChannelMembersResponse { members, users })?;
3195 Ok(())
3196}
3197
3198/// Accept or decline a channel invitation.
3199async fn respond_to_channel_invite(
3200 request: proto::RespondToChannelInvite,
3201 response: Response<proto::RespondToChannelInvite>,
3202 session: Session,
3203) -> Result<()> {
3204 let db = session.db().await;
3205 let channel_id = ChannelId::from_proto(request.channel_id);
3206 let RespondToChannelInvite {
3207 membership_update,
3208 notifications,
3209 } = db
3210 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3211 .await?;
3212
3213 let mut connection_pool = session.connection_pool().await;
3214 if let Some(membership_update) = membership_update {
3215 notify_membership_updated(
3216 &mut connection_pool,
3217 membership_update,
3218 session.user_id(),
3219 &session.peer,
3220 );
3221 } else {
3222 let update = proto::UpdateChannels {
3223 remove_channel_invitations: vec![channel_id.to_proto()],
3224 ..Default::default()
3225 };
3226
3227 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3228 session.peer.send(connection_id, update.clone())?;
3229 }
3230 };
3231
3232 send_notifications(&connection_pool, &session.peer, notifications);
3233
3234 response.send(proto::Ack {})?;
3235
3236 Ok(())
3237}
3238
3239/// Join the channels' room
3240async fn join_channel(
3241 request: proto::JoinChannel,
3242 response: Response<proto::JoinChannel>,
3243 session: Session,
3244) -> Result<()> {
3245 let channel_id = ChannelId::from_proto(request.channel_id);
3246 join_channel_internal(channel_id, Box::new(response), session).await
3247}
3248
/// Common interface for the response handles that `join_channel_internal`
/// can reply through. Both `JoinChannel` and `JoinRoom` requests are answered
/// with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Resolves to `Response`'s own (presumably inherent) `send`. Inherent
        // methods take precedence over trait methods, so this does not
        // recurse into this impl.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as the `JoinChannel` impl: call `Response`'s own
        // `send`, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3262
/// Join the room of `channel_id` for the session's user and reply through
/// `response` (a `JoinChannel` or `JoinRoom` response handle).
///
/// In order, this:
/// 1. leaves any stale room connection the database still records for the
///    user (the database guard is released while doing so, then reacquired);
/// 2. joins the channel's room in the database, which also yields any
///    membership change and the user's `ChannelRole`;
/// 3. if a LiveKit client is configured, mints a token. Guests get a guest
///    token with `can_publish: false`; every other role gets a room token with
///    `can_publish: true`. A token error is passed through `trace_err` and
///    only omits the connection info; it does not fail the request;
/// 4. sends the `JoinRoomResponse`, applies/pushes any membership update, and
///    broadcasts the updated room to its participants;
/// 5. once the database guard is dropped, broadcasts the channel's participant
///    list and pushes the user's contact entry to their contacts.
///
/// # Errors
/// Propagates database and send errors. Errors with "channel not returned" if
/// the joined room carries no channel. Note that this check runs only after
/// the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database guard (`db`) so that it is released
    // before the channel/contacts broadcasts that follow it.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database guard itself, so
            // release ours first and take it again afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when LiveKit is not configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The response has already been sent at this point, so a missing channel
    // only surfaces as an error from this handler.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3358
3359/// Start editing the channel notes
3360async fn join_channel_buffer(
3361 request: proto::JoinChannelBuffer,
3362 response: Response<proto::JoinChannelBuffer>,
3363 session: Session,
3364) -> Result<()> {
3365 let db = session.db().await;
3366 let channel_id = ChannelId::from_proto(request.channel_id);
3367
3368 let open_response = db
3369 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3370 .await?;
3371
3372 let collaborators = open_response.collaborators.clone();
3373 response.send(open_response)?;
3374
3375 let update = UpdateChannelBufferCollaborators {
3376 channel_id: channel_id.to_proto(),
3377 collaborators: collaborators.clone(),
3378 };
3379 channel_buffer_updated(
3380 session.connection_id,
3381 collaborators
3382 .iter()
3383 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3384 &update,
3385 &session.peer,
3386 );
3387
3388 Ok(())
3389}
3390
3391/// Edit the channel notes
3392async fn update_channel_buffer(
3393 request: proto::UpdateChannelBuffer,
3394 session: Session,
3395) -> Result<()> {
3396 let db = session.db().await;
3397 let channel_id = ChannelId::from_proto(request.channel_id);
3398
3399 let (collaborators, epoch, version) = db
3400 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3401 .await?;
3402
3403 channel_buffer_updated(
3404 session.connection_id,
3405 collaborators.clone(),
3406 &proto::UpdateChannelBuffer {
3407 channel_id: channel_id.to_proto(),
3408 operations: request.operations,
3409 },
3410 &session.peer,
3411 );
3412
3413 let pool = &*session.connection_pool().await;
3414
3415 let non_collaborators =
3416 pool.channel_connection_ids(channel_id)
3417 .filter_map(|(connection_id, _)| {
3418 if collaborators.contains(&connection_id) {
3419 None
3420 } else {
3421 Some(connection_id)
3422 }
3423 });
3424
3425 broadcast(None, non_collaborators, |peer_id| {
3426 session.peer.send(
3427 peer_id,
3428 proto::UpdateChannels {
3429 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3430 channel_id: channel_id.to_proto(),
3431 epoch: epoch as u64,
3432 version: version.clone(),
3433 }],
3434 ..Default::default()
3435 },
3436 )
3437 });
3438
3439 Ok(())
3440}
3441
3442/// Rejoin the channel notes after a connection blip
3443async fn rejoin_channel_buffers(
3444 request: proto::RejoinChannelBuffers,
3445 response: Response<proto::RejoinChannelBuffers>,
3446 session: Session,
3447) -> Result<()> {
3448 let db = session.db().await;
3449 let buffers = db
3450 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3451 .await?;
3452
3453 for rejoined_buffer in &buffers {
3454 let collaborators_to_notify = rejoined_buffer
3455 .buffer
3456 .collaborators
3457 .iter()
3458 .filter_map(|c| Some(c.peer_id?.into()));
3459 channel_buffer_updated(
3460 session.connection_id,
3461 collaborators_to_notify,
3462 &proto::UpdateChannelBufferCollaborators {
3463 channel_id: rejoined_buffer.buffer.channel_id,
3464 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3465 },
3466 &session.peer,
3467 );
3468 }
3469
3470 response.send(proto::RejoinChannelBuffersResponse {
3471 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3472 })?;
3473
3474 Ok(())
3475}
3476
3477/// Stop editing the channel notes
3478async fn leave_channel_buffer(
3479 request: proto::LeaveChannelBuffer,
3480 response: Response<proto::LeaveChannelBuffer>,
3481 session: Session,
3482) -> Result<()> {
3483 let db = session.db().await;
3484 let channel_id = ChannelId::from_proto(request.channel_id);
3485
3486 let left_buffer = db
3487 .leave_channel_buffer(channel_id, session.connection_id)
3488 .await?;
3489
3490 response.send(Ack {})?;
3491
3492 channel_buffer_updated(
3493 session.connection_id,
3494 left_buffer.connections,
3495 &proto::UpdateChannelBufferCollaborators {
3496 channel_id: channel_id.to_proto(),
3497 collaborators: left_buffer.collaborators,
3498 },
3499 &session.peer,
3500 );
3501
3502 Ok(())
3503}
3504
3505fn channel_buffer_updated<T: EnvelopedMessage>(
3506 sender_id: ConnectionId,
3507 collaborators: impl IntoIterator<Item = ConnectionId>,
3508 message: &T,
3509 peer: &Peer,
3510) {
3511 broadcast(Some(sender_id), collaborators, |peer_id| {
3512 peer.send(peer_id, message.clone())
3513 });
3514}
3515
3516fn send_notifications(
3517 connection_pool: &ConnectionPool,
3518 peer: &Peer,
3519 notifications: db::NotificationBatch,
3520) {
3521 for (user_id, notification) in notifications {
3522 for connection_id in connection_pool.user_connection_ids(user_id) {
3523 if let Err(error) = peer.send(
3524 connection_id,
3525 proto::AddNotification {
3526 notification: Some(notification.clone()),
3527 },
3528 ) {
3529 tracing::error!(
3530 "failed to send notification to {:?} {}",
3531 connection_id,
3532 error
3533 );
3534 }
3535 }
3536 }
3537}
3538
3539/// Send a message to the channel
3540async fn send_channel_message(
3541 request: proto::SendChannelMessage,
3542 response: Response<proto::SendChannelMessage>,
3543 session: Session,
3544) -> Result<()> {
3545 // Validate the message body.
3546 let body = request.body.trim().to_string();
3547 if body.len() > MAX_MESSAGE_LEN {
3548 return Err(anyhow!("message is too long"))?;
3549 }
3550 if body.is_empty() {
3551 return Err(anyhow!("message can't be blank"))?;
3552 }
3553
3554 // TODO: adjust mentions if body is trimmed
3555
3556 let timestamp = OffsetDateTime::now_utc();
3557 let nonce = request
3558 .nonce
3559 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3560
3561 let channel_id = ChannelId::from_proto(request.channel_id);
3562 let CreatedChannelMessage {
3563 message_id,
3564 participant_connection_ids,
3565 notifications,
3566 } = session
3567 .db()
3568 .await
3569 .create_channel_message(
3570 channel_id,
3571 session.user_id(),
3572 &body,
3573 &request.mentions,
3574 timestamp,
3575 nonce.clone().into(),
3576 request.reply_to_message_id.map(MessageId::from_proto),
3577 )
3578 .await?;
3579
3580 let message = proto::ChannelMessage {
3581 sender_id: session.user_id().to_proto(),
3582 id: message_id.to_proto(),
3583 body,
3584 mentions: request.mentions,
3585 timestamp: timestamp.unix_timestamp() as u64,
3586 nonce: Some(nonce),
3587 reply_to_message_id: request.reply_to_message_id,
3588 edited_at: None,
3589 };
3590 broadcast(
3591 Some(session.connection_id),
3592 participant_connection_ids.clone(),
3593 |connection| {
3594 session.peer.send(
3595 connection,
3596 proto::ChannelMessageSent {
3597 channel_id: channel_id.to_proto(),
3598 message: Some(message.clone()),
3599 },
3600 )
3601 },
3602 );
3603 response.send(proto::SendChannelMessageResponse {
3604 message: Some(message),
3605 })?;
3606
3607 let pool = &*session.connection_pool().await;
3608 let non_participants =
3609 pool.channel_connection_ids(channel_id)
3610 .filter_map(|(connection_id, _)| {
3611 if participant_connection_ids.contains(&connection_id) {
3612 None
3613 } else {
3614 Some(connection_id)
3615 }
3616 });
3617 broadcast(None, non_participants, |peer_id| {
3618 session.peer.send(
3619 peer_id,
3620 proto::UpdateChannels {
3621 latest_channel_message_ids: vec![proto::ChannelMessageId {
3622 channel_id: channel_id.to_proto(),
3623 message_id: message_id.to_proto(),
3624 }],
3625 ..Default::default()
3626 },
3627 )
3628 });
3629 send_notifications(pool, &session.peer, notifications);
3630
3631 Ok(())
3632}
3633
3634/// Delete a channel message
3635async fn remove_channel_message(
3636 request: proto::RemoveChannelMessage,
3637 response: Response<proto::RemoveChannelMessage>,
3638 session: Session,
3639) -> Result<()> {
3640 let channel_id = ChannelId::from_proto(request.channel_id);
3641 let message_id = MessageId::from_proto(request.message_id);
3642 let (connection_ids, existing_notification_ids) = session
3643 .db()
3644 .await
3645 .remove_channel_message(channel_id, message_id, session.user_id())
3646 .await?;
3647
3648 broadcast(
3649 Some(session.connection_id),
3650 connection_ids,
3651 move |connection| {
3652 session.peer.send(connection, request.clone())?;
3653
3654 for notification_id in &existing_notification_ids {
3655 session.peer.send(
3656 connection,
3657 proto::DeleteNotification {
3658 notification_id: (*notification_id).to_proto(),
3659 },
3660 )?;
3661 }
3662
3663 Ok(())
3664 },
3665 );
3666 response.send(proto::Ack {})?;
3667 Ok(())
3668}
3669
3670async fn update_channel_message(
3671 request: proto::UpdateChannelMessage,
3672 response: Response<proto::UpdateChannelMessage>,
3673 session: Session,
3674) -> Result<()> {
3675 let channel_id = ChannelId::from_proto(request.channel_id);
3676 let message_id = MessageId::from_proto(request.message_id);
3677 let updated_at = OffsetDateTime::now_utc();
3678 let UpdatedChannelMessage {
3679 message_id,
3680 participant_connection_ids,
3681 notifications,
3682 reply_to_message_id,
3683 timestamp,
3684 deleted_mention_notification_ids,
3685 updated_mention_notifications,
3686 } = session
3687 .db()
3688 .await
3689 .update_channel_message(
3690 channel_id,
3691 message_id,
3692 session.user_id(),
3693 request.body.as_str(),
3694 &request.mentions,
3695 updated_at,
3696 )
3697 .await?;
3698
3699 let nonce = request
3700 .nonce
3701 .clone()
3702 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3703
3704 let message = proto::ChannelMessage {
3705 sender_id: session.user_id().to_proto(),
3706 id: message_id.to_proto(),
3707 body: request.body.clone(),
3708 mentions: request.mentions.clone(),
3709 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3710 nonce: Some(nonce),
3711 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3712 edited_at: Some(updated_at.unix_timestamp() as u64),
3713 };
3714
3715 response.send(proto::Ack {})?;
3716
3717 let pool = &*session.connection_pool().await;
3718 broadcast(
3719 Some(session.connection_id),
3720 participant_connection_ids,
3721 |connection| {
3722 session.peer.send(
3723 connection,
3724 proto::ChannelMessageUpdate {
3725 channel_id: channel_id.to_proto(),
3726 message: Some(message.clone()),
3727 },
3728 )?;
3729
3730 for notification_id in &deleted_mention_notification_ids {
3731 session.peer.send(
3732 connection,
3733 proto::DeleteNotification {
3734 notification_id: (*notification_id).to_proto(),
3735 },
3736 )?;
3737 }
3738
3739 for notification in &updated_mention_notifications {
3740 session.peer.send(
3741 connection,
3742 proto::UpdateNotification {
3743 notification: Some(notification.clone()),
3744 },
3745 )?;
3746 }
3747
3748 Ok(())
3749 },
3750 );
3751
3752 send_notifications(pool, &session.peer, notifications);
3753
3754 Ok(())
3755}
3756
3757/// Mark a channel message as read
3758async fn acknowledge_channel_message(
3759 request: proto::AckChannelMessage,
3760 session: Session,
3761) -> Result<()> {
3762 let channel_id = ChannelId::from_proto(request.channel_id);
3763 let message_id = MessageId::from_proto(request.message_id);
3764 let notifications = session
3765 .db()
3766 .await
3767 .observe_channel_message(channel_id, session.user_id(), message_id)
3768 .await?;
3769 send_notifications(
3770 &*session.connection_pool().await,
3771 &session.peer,
3772 notifications,
3773 );
3774 Ok(())
3775}
3776
3777/// Mark a buffer version as synced
3778async fn acknowledge_buffer_version(
3779 request: proto::AckBufferOperation,
3780 session: Session,
3781) -> Result<()> {
3782 let buffer_id = BufferId::from_proto(request.buffer_id);
3783 session
3784 .db()
3785 .await
3786 .observe_buffer_version(
3787 buffer_id,
3788 session.user_id(),
3789 request.epoch as i32,
3790 &request.version,
3791 )
3792 .await?;
3793 Ok(())
3794}
3795
3796/// Get a Supermaven API key for the user
3797async fn get_supermaven_api_key(
3798 _request: proto::GetSupermavenApiKey,
3799 response: Response<proto::GetSupermavenApiKey>,
3800 session: Session,
3801) -> Result<()> {
3802 let user_id: String = session.user_id().to_string();
3803 if !session.is_staff() {
3804 return Err(anyhow!("supermaven not enabled for this account"))?;
3805 }
3806
3807 let email = session
3808 .email()
3809 .ok_or_else(|| anyhow!("user must have an email"))?;
3810
3811 let supermaven_admin_api = session
3812 .supermaven_client
3813 .as_ref()
3814 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3815
3816 let result = supermaven_admin_api
3817 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3818 .await?;
3819
3820 response.send(proto::GetSupermavenApiKeyResponse {
3821 api_key: result.api_key,
3822 })?;
3823
3824 Ok(())
3825}
3826
3827/// Start receiving chat updates for a channel
3828async fn join_channel_chat(
3829 request: proto::JoinChannelChat,
3830 response: Response<proto::JoinChannelChat>,
3831 session: Session,
3832) -> Result<()> {
3833 let channel_id = ChannelId::from_proto(request.channel_id);
3834
3835 let db = session.db().await;
3836 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3837 .await?;
3838 let messages = db
3839 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3840 .await?;
3841 response.send(proto::JoinChannelChatResponse {
3842 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3843 messages,
3844 })?;
3845 Ok(())
3846}
3847
3848/// Stop receiving chat updates for a channel
3849async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3850 let channel_id = ChannelId::from_proto(request.channel_id);
3851 session
3852 .db()
3853 .await
3854 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3855 .await?;
3856 Ok(())
3857}
3858
3859/// Retrieve the chat history for a channel
3860async fn get_channel_messages(
3861 request: proto::GetChannelMessages,
3862 response: Response<proto::GetChannelMessages>,
3863 session: Session,
3864) -> Result<()> {
3865 let channel_id = ChannelId::from_proto(request.channel_id);
3866 let messages = session
3867 .db()
3868 .await
3869 .get_channel_messages(
3870 channel_id,
3871 session.user_id(),
3872 MESSAGE_COUNT_PER_PAGE,
3873 Some(MessageId::from_proto(request.before_message_id)),
3874 )
3875 .await?;
3876 response.send(proto::GetChannelMessagesResponse {
3877 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3878 messages,
3879 })?;
3880 Ok(())
3881}
3882
3883/// Retrieve specific chat messages
3884async fn get_channel_messages_by_id(
3885 request: proto::GetChannelMessagesById,
3886 response: Response<proto::GetChannelMessagesById>,
3887 session: Session,
3888) -> Result<()> {
3889 let message_ids = request
3890 .message_ids
3891 .iter()
3892 .map(|id| MessageId::from_proto(*id))
3893 .collect::<Vec<_>>();
3894 let messages = session
3895 .db()
3896 .await
3897 .get_channel_messages_by_id(session.user_id(), &message_ids)
3898 .await?;
3899 response.send(proto::GetChannelMessagesResponse {
3900 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3901 messages,
3902 })?;
3903 Ok(())
3904}
3905
3906/// Retrieve the current users notifications
3907async fn get_notifications(
3908 request: proto::GetNotifications,
3909 response: Response<proto::GetNotifications>,
3910 session: Session,
3911) -> Result<()> {
3912 let notifications = session
3913 .db()
3914 .await
3915 .get_notifications(
3916 session.user_id(),
3917 NOTIFICATION_COUNT_PER_PAGE,
3918 request.before_id.map(db::NotificationId::from_proto),
3919 )
3920 .await?;
3921 response.send(proto::GetNotificationsResponse {
3922 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3923 notifications,
3924 })?;
3925 Ok(())
3926}
3927
3928/// Mark notifications as read
3929async fn mark_notification_as_read(
3930 request: proto::MarkNotificationRead,
3931 response: Response<proto::MarkNotificationRead>,
3932 session: Session,
3933) -> Result<()> {
3934 let database = &session.db().await;
3935 let notifications = database
3936 .mark_notification_as_read_by_id(
3937 session.user_id(),
3938 NotificationId::from_proto(request.notification_id),
3939 )
3940 .await?;
3941 send_notifications(
3942 &*session.connection_pool().await,
3943 &session.peer,
3944 notifications,
3945 );
3946 response.send(proto::Ack {})?;
3947 Ok(())
3948}
3949
3950/// Get the current users information
3951async fn get_private_user_info(
3952 _request: proto::GetPrivateUserInfo,
3953 response: Response<proto::GetPrivateUserInfo>,
3954 session: Session,
3955) -> Result<()> {
3956 let db = session.db().await;
3957
3958 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3959 let user = db
3960 .get_user_by_id(session.user_id())
3961 .await?
3962 .ok_or_else(|| anyhow!("user not found"))?;
3963 let flags = db.get_user_flags(session.user_id()).await?;
3964
3965 response.send(proto::GetPrivateUserInfoResponse {
3966 metrics_id,
3967 staff: user.admin,
3968 flags,
3969 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3970 })?;
3971 Ok(())
3972}
3973
3974/// Accept the terms of service (tos) on behalf of the current user
3975async fn accept_terms_of_service(
3976 _request: proto::AcceptTermsOfService,
3977 response: Response<proto::AcceptTermsOfService>,
3978 session: Session,
3979) -> Result<()> {
3980 let db = session.db().await;
3981
3982 let accepted_tos_at = Utc::now();
3983 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3984 .await?;
3985
3986 response.send(proto::AcceptTermsOfServiceResponse {
3987 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3988 })?;
3989 Ok(())
3990}
3991
/// The minimum account age an account must have in order to use the LLM
/// service: 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
3994
3995async fn get_llm_api_token(
3996 _request: proto::GetLlmToken,
3997 response: Response<proto::GetLlmToken>,
3998 session: Session,
3999) -> Result<()> {
4000 let db = session.db().await;
4001
4002 let flags = db.get_user_flags(session.user_id()).await?;
4003 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4004
4005 if !session.is_staff() && !has_language_models_feature_flag {
4006 Err(anyhow!("permission denied"))?
4007 }
4008
4009 let user_id = session.user_id();
4010 let user = db
4011 .get_user_by_id(user_id)
4012 .await?
4013 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4014
4015 if user.accepted_tos_at.is_none() {
4016 Err(anyhow!("terms of service not accepted"))?
4017 }
4018
4019 let has_legacy_llm_subscription = session.has_llm_subscription(&db).await?;
4020 let billing_subscription = db.get_active_billing_subscription(user.id).await?;
4021 let billing_preferences = db.get_billing_preferences(user.id).await?;
4022
4023 let token = LlmTokenClaims::create(
4024 &user,
4025 session.is_staff(),
4026 billing_preferences,
4027 &flags,
4028 has_legacy_llm_subscription,
4029 billing_subscription,
4030 session.system_id.clone(),
4031 &session.app_state.config,
4032 )?;
4033 response.send(proto::GetLlmTokenResponse { token })?;
4034 Ok(())
4035}
4036
4037fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4038 let message = match message {
4039 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4040 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4041 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4042 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4043 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4044 code: frame.code.into(),
4045 reason: frame.reason.as_str().to_owned().into(),
4046 })),
4047 // We should never receive a frame while reading the message, according
4048 // to the `tungstenite` maintainers:
4049 //
4050 // > It cannot occur when you read messages from the WebSocket, but it
4051 // > can be used when you want to send the raw frames (e.g. you want to
4052 // > send the frames to the WebSocket without composing the full message first).
4053 // >
4054 // > — https://github.com/snapview/tungstenite-rs/issues/268
4055 TungsteniteMessage::Frame(_) => {
4056 bail!("received an unexpected frame while reading the message")
4057 }
4058 };
4059
4060 Ok(message)
4061}
4062
4063fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4064 match message {
4065 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4066 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4067 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4068 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4069 AxumMessage::Close(frame) => {
4070 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4071 code: frame.code.into(),
4072 reason: frame.reason.as_ref().into(),
4073 }))
4074 }
4075 }
4076}
4077
4078fn notify_membership_updated(
4079 connection_pool: &mut ConnectionPool,
4080 result: MembershipUpdated,
4081 user_id: UserId,
4082 peer: &Peer,
4083) {
4084 for membership in &result.new_channels.channel_memberships {
4085 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4086 }
4087 for channel_id in &result.removed_channels {
4088 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4089 }
4090
4091 let user_channels_update = proto::UpdateUserChannels {
4092 channel_memberships: result
4093 .new_channels
4094 .channel_memberships
4095 .iter()
4096 .map(|cm| proto::ChannelMembership {
4097 channel_id: cm.channel_id.to_proto(),
4098 role: cm.role.into(),
4099 })
4100 .collect(),
4101 ..Default::default()
4102 };
4103
4104 let mut update = build_channels_update(result.new_channels);
4105 update.delete_channels = result
4106 .removed_channels
4107 .into_iter()
4108 .map(|id| id.to_proto())
4109 .collect();
4110 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4111
4112 for connection_id in connection_pool.user_connection_ids(user_id) {
4113 peer.send(connection_id, user_channels_update.clone())
4114 .trace_err();
4115 peer.send(connection_id, update.clone()).trace_err();
4116 }
4117}
4118
4119fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4120 proto::UpdateUserChannels {
4121 channel_memberships: channels
4122 .channel_memberships
4123 .iter()
4124 .map(|m| proto::ChannelMembership {
4125 channel_id: m.channel_id.to_proto(),
4126 role: m.role.into(),
4127 })
4128 .collect(),
4129 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4130 observed_channel_message_id: channels.observed_channel_messages.clone(),
4131 }
4132}
4133
4134fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4135 let mut update = proto::UpdateChannels::default();
4136
4137 for channel in channels.channels {
4138 update.channels.push(channel.to_proto());
4139 }
4140
4141 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4142 update.latest_channel_message_ids = channels.latest_channel_messages;
4143
4144 for (channel_id, participants) in channels.channel_participants {
4145 update
4146 .channel_participants
4147 .push(proto::ChannelParticipants {
4148 channel_id: channel_id.to_proto(),
4149 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4150 });
4151 }
4152
4153 for channel in channels.invited_channels {
4154 update.channel_invitations.push(channel.to_proto());
4155 }
4156
4157 update
4158}
4159
4160fn build_initial_contacts_update(
4161 contacts: Vec<db::Contact>,
4162 pool: &ConnectionPool,
4163) -> proto::UpdateContacts {
4164 let mut update = proto::UpdateContacts::default();
4165
4166 for contact in contacts {
4167 match contact {
4168 db::Contact::Accepted { user_id, busy } => {
4169 update.contacts.push(contact_for_user(user_id, busy, pool));
4170 }
4171 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4172 db::Contact::Incoming { user_id } => {
4173 update
4174 .incoming_requests
4175 .push(proto::IncomingContactRequest {
4176 requester_id: user_id.to_proto(),
4177 })
4178 }
4179 }
4180 }
4181
4182 update
4183}
4184
4185fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4186 proto::Contact {
4187 user_id: user_id.to_proto(),
4188 online: pool.is_user_online(user_id),
4189 busy,
4190 }
4191}
4192
4193fn room_updated(room: &proto::Room, peer: &Peer) {
4194 broadcast(
4195 None,
4196 room.participants
4197 .iter()
4198 .filter_map(|participant| Some(participant.peer_id?.into())),
4199 |peer_id| {
4200 peer.send(
4201 peer_id,
4202 proto::RoomUpdated {
4203 room: Some(room.clone()),
4204 },
4205 )
4206 },
4207 );
4208}
4209
4210fn channel_updated(
4211 channel: &db::channel::Model,
4212 room: &proto::Room,
4213 peer: &Peer,
4214 pool: &ConnectionPool,
4215) {
4216 let participants = room
4217 .participants
4218 .iter()
4219 .map(|p| p.user_id)
4220 .collect::<Vec<_>>();
4221
4222 broadcast(
4223 None,
4224 pool.channel_connection_ids(channel.root_id())
4225 .filter_map(|(channel_id, role)| {
4226 role.can_see_channel(channel.visibility)
4227 .then_some(channel_id)
4228 }),
4229 |peer_id| {
4230 peer.send(
4231 peer_id,
4232 proto::UpdateChannels {
4233 channel_participants: vec![proto::ChannelParticipants {
4234 channel_id: channel.id.to_proto(),
4235 participant_user_ids: participants.clone(),
4236 }],
4237 ..Default::default()
4238 },
4239 )
4240 },
4241 );
4242}
4243
4244async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4245 let db = session.db().await;
4246
4247 let contacts = db.get_contacts(user_id).await?;
4248 let busy = db.is_user_busy(user_id).await?;
4249
4250 let pool = session.connection_pool().await;
4251 let updated_contact = contact_for_user(user_id, busy, &pool);
4252 for contact in contacts {
4253 if let db::Contact::Accepted {
4254 user_id: contact_user_id,
4255 ..
4256 } = contact
4257 {
4258 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4259 session
4260 .peer
4261 .send(
4262 contact_conn_id,
4263 proto::UpdateContacts {
4264 contacts: vec![updated_contact.clone()],
4265 remove_contacts: Default::default(),
4266 incoming_requests: Default::default(),
4267 remove_incoming_requests: Default::default(),
4268 outgoing_requests: Default::default(),
4269 remove_outgoing_requests: Default::default(),
4270 },
4271 )
4272 .trace_err();
4273 }
4274 }
4275 }
4276 Ok(())
4277}
4278
/// Removes `connection_id` from the room it is in and propagates the
/// consequences to other clients and to LiveKit.
///
/// If the database reports that the connection was not in a room
/// (`leave_room` returns `None`), this returns `Ok(())` without doing anything
/// else. Otherwise it:
/// - notifies the collaborators of every project the connection left (`project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - broadcasts the channel's participant list if the room belonged to a channel,
/// - sends `CallCanceled` to every connection of users whose pending calls were canceled,
/// - refreshes the contact entry of the leaving user and of those users,
/// - if a LiveKit client is configured, removes the user from the LiveKit room,
///   and deletes that room when `left_room.deleted` is set.
///
/// Message-send and LiveKit failures are only logged (`trace_err`).
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// NOTE(review): because of the `?` on `update_user_contacts`, a failure there
/// returns early and skips the LiveKit cleanup at the end. Confirm this is intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and initialized inside the `if let` below, so the values
    // taken out of `left_room` outlive that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the guard returned by `session.db()` is a temporary of this
    // scrutinee, so it stays alive for (at least) the body of this `if let`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below. As a result, the `room` broadcast
        // by `room_updated` carries a default (empty) `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool guard is dropped before `update_user_contacts`, which
    // acquires its own via `session.connection_pool()`.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4352
4353async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4354 let left_channel_buffers = session
4355 .db()
4356 .await
4357 .leave_channel_buffers(session.connection_id)
4358 .await?;
4359
4360 for left_buffer in left_channel_buffers {
4361 channel_buffer_updated(
4362 session.connection_id,
4363 left_buffer.connections,
4364 &proto::UpdateChannelBufferCollaborators {
4365 channel_id: left_buffer.channel_id.to_proto(),
4366 collaborators: left_buffer.collaborators,
4367 },
4368 &session.peer,
4369 );
4370 }
4371
4372 Ok(())
4373}
4374
4375fn project_left(project: &db::LeftProject, session: &Session) {
4376 for connection_id in &project.connection_ids {
4377 if project.should_unshare {
4378 session
4379 .peer
4380 .send(
4381 *connection_id,
4382 proto::UnshareProject {
4383 project_id: project.id.to_proto(),
4384 },
4385 )
4386 .trace_err();
4387 } else {
4388 session
4389 .peer
4390 .send(
4391 *connection_id,
4392 proto::RemoveProjectCollaborator {
4393 project_id: project.id.to_proto(),
4394 peer_id: Some(session.connection_id.into()),
4395 },
4396 )
4397 .trace_err();
4398 }
4399 }
4400}
4401
4402pub trait ResultExt {
4403 type Ok;
4404
4405 fn trace_err(self) -> Option<Self::Ok>;
4406}
4407
4408impl<T, E> ResultExt for Result<T, E>
4409where
4410 E: std::fmt::Debug,
4411{
4412 type Ok = T;
4413
4414 #[track_caller]
4415 fn trace_err(self) -> Option<T> {
4416 match self {
4417 Ok(value) => Some(value),
4418 Err(error) => {
4419 tracing::error!("{:?}", error);
4420 None
4421 }
4422 }
4423 }
4424}