1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::stripe_client::StripeCustomerId;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{
12 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
13 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
14 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
15 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
16 },
17 executor::Executor,
18};
19use anyhow::{Context as _, anyhow, bail};
20use async_tungstenite::tungstenite::{
21 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
22};
23use axum::{
24 Extension, Router, TypedHeader,
25 body::Body,
26 extract::{
27 ConnectInfo, WebSocketUpgrade,
28 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
29 },
30 headers::{Header, HeaderName},
31 http::StatusCode,
32 middleware,
33 response::IntoResponse,
34 routing::get,
35};
36use chrono::Utc;
37use collections::{HashMap, HashSet};
38pub use connection_pool::{ConnectionPool, ZedVersion};
39use core::fmt::{self, Debug, Formatter};
40use reqwest_client::ReqwestClient;
41use rpc::proto::split_repository_update;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
46 stream::FuturesUnordered,
47};
48use prometheus::{IntGauge, register_int_gauge};
49use rpc::{
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 Arc, OnceLock,
68 atomic::{AtomicBool, Ordering::SeqCst},
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{Semaphore, watch};
74use tower::ServiceBuilder;
75use tracing::{
76 Instrument,
77 field::{self},
78 info_span, instrument,
79};
80
/// Reconnection grace period.
///
/// NOTE(review): consumers live outside this section; presumably this is the
/// window a dropped client has to reconnect before its state is released.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a newly started server cleans up resources that the database
/// still attributes to stale servers (see `Server::start`).
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting a bit
/// longer than that ensures the old pods are gone before their resources are
/// cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for channel message history (presumably; used outside this section).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Upper bound on channel message length (presumably; used outside this section).
const MAX_MESSAGE_LEN: usize = 1024;
// Page size for notification queries (presumably; used outside this section).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111}
112
113impl Principal {
114 fn user(&self) -> &User {
115 match self {
116 Principal::User(user) => user,
117 Principal::Impersonated { user, .. } => user,
118 }
119 }
120
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 }
133 }
134}
135
/// Per-connection state. It is created once in `handle_connection` and cloned
/// for every message handler invocation.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Each session wraps the shared `Database` in its own mutex (built in
    // `handle_connection`); it is accessed through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // The same pool the owning `Server` holds.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` is presumably here because nothing reads
    // this field yet; confirm before removing it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn is_staff(&self) -> bool {
173 match &self.principal {
174 Principal::User(user) => user.admin,
175 Principal::Impersonated { .. } => true,
176 }
177 }
178
179 fn user_id(&self) -> UserId {
180 match &self.principal {
181 Principal::User(user) => user.id,
182 Principal::Impersonated { user, .. } => user.id,
183 }
184 }
185
186 pub fn email(&self) -> Option<String> {
187 match &self.principal {
188 Principal::User(user) => user.email_address.clone(),
189 Principal::Impersonated { user, .. } => user.email_address.clone(),
190 }
191 }
192}
193
194impl Debug for Session {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let mut result = f.debug_struct("Session");
197 match &self.principal {
198 Principal::User(user) => {
199 result.field("user", &user.github_login);
200 }
201 Principal::Impersonated { user, admin } => {
202 result.field("user", &user.github_login);
203 result.field("impersonator", &admin.github_login);
204 }
205 }
206 result.field("connection_id", &self.connection_id).finish()
207 }
208}
209
210struct DbHandle(Arc<Database>);
211
212impl Deref for DbHandle {
213 type Target = Database;
214
215 fn deref(&self) -> &Self::Target {
216 self.0.as_ref()
217 }
218}
219
/// The collab RPC server. It owns the peer, the shared connection pool, and the
/// table of registered message handlers.
pub struct Server {
    // Behind a mutex so it can be rewritten; in this view only the test-only
    // `reset` writes it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by the `TypeId` of the message payload type (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown` sends `true`; `handle_connection` subscribes and stops serving.
    teardown: watch::Sender<bool>,
}
228
/// A locked [`ConnectionPool`], as returned by `Session::connection_pool`.
///
/// `PhantomData<Rc<()>>` makes the guard `!Send`. Holding it across an `.await`
/// would make the enclosing future `!Send`, which surfaces as a compile error
/// wherever a `Send` future is required.
/// NOTE(review): its `Deref` impl (required by `serialize_deref`) is outside this view.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
233
/// Serializable view of the server: the peer and the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the pool the guard points at, via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
240
241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
242where
243 S: Serializer,
244 T: Deref<Target = U>,
245 U: Serialize,
246{
247 Serialize::serialize(value.deref(), serializer)
248}
249
250impl Server {
251 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
252 let mut server = Self {
253 id: parking_lot::Mutex::new(id),
254 peer: Peer::new(id.0 as u32),
255 app_state: app_state.clone(),
256 connection_pool: Default::default(),
257 handlers: Default::default(),
258 teardown: watch::channel(false).0,
259 };
260
261 server
262 .add_request_handler(ping)
263 .add_request_handler(create_room)
264 .add_request_handler(join_room)
265 .add_request_handler(rejoin_room)
266 .add_request_handler(leave_room)
267 .add_request_handler(set_room_participant_role)
268 .add_request_handler(call)
269 .add_request_handler(cancel_call)
270 .add_message_handler(decline_call)
271 .add_request_handler(update_participant_location)
272 .add_request_handler(share_project)
273 .add_message_handler(unshare_project)
274 .add_request_handler(join_project)
275 .add_message_handler(leave_project)
276 .add_request_handler(update_project)
277 .add_request_handler(update_worktree)
278 .add_request_handler(update_repository)
279 .add_request_handler(remove_repository)
280 .add_message_handler(start_language_server)
281 .add_message_handler(update_language_server)
282 .add_message_handler(update_diagnostic_summary)
283 .add_message_handler(update_worktree_settings)
284 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
287 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
288 .add_request_handler(forward_find_search_candidates_request)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
294 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
295 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
296 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
297 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
298 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
299 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
304 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
305 .add_request_handler(
306 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
307 )
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
312 .add_request_handler(
313 forward_read_only_project_request::<proto::LanguageServerIdForName>,
314 )
315 .add_request_handler(
316 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
317 )
318 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
319 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
320 .add_request_handler(
321 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
322 )
323 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
324 .add_request_handler(
325 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
328 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
329 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
331 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
333 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
334 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
339 .add_request_handler(
340 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
341 )
342 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
343 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
345 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
346 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
347 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
348 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
349 .add_message_handler(create_buffer_for_peer)
350 .add_request_handler(update_buffer)
351 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
357 .add_request_handler(get_users)
358 .add_request_handler(fuzzy_search_users)
359 .add_request_handler(request_contact)
360 .add_request_handler(remove_contact)
361 .add_request_handler(respond_to_contact_request)
362 .add_message_handler(subscribe_to_channels)
363 .add_request_handler(create_channel)
364 .add_request_handler(delete_channel)
365 .add_request_handler(invite_channel_member)
366 .add_request_handler(remove_channel_member)
367 .add_request_handler(set_channel_member_role)
368 .add_request_handler(set_channel_visibility)
369 .add_request_handler(rename_channel)
370 .add_request_handler(join_channel_buffer)
371 .add_request_handler(leave_channel_buffer)
372 .add_message_handler(update_channel_buffer)
373 .add_request_handler(rejoin_channel_buffers)
374 .add_request_handler(get_channel_members)
375 .add_request_handler(respond_to_channel_invite)
376 .add_request_handler(join_channel)
377 .add_request_handler(join_channel_chat)
378 .add_message_handler(leave_channel_chat)
379 .add_request_handler(send_channel_message)
380 .add_request_handler(remove_channel_message)
381 .add_request_handler(update_channel_message)
382 .add_request_handler(get_channel_messages)
383 .add_request_handler(get_channel_messages_by_id)
384 .add_request_handler(get_notifications)
385 .add_request_handler(mark_notification_as_read)
386 .add_request_handler(move_channel)
387 .add_request_handler(follow)
388 .add_message_handler(unfollow)
389 .add_message_handler(update_followers)
390 .add_request_handler(get_private_user_info)
391 .add_request_handler(get_llm_api_token)
392 .add_request_handler(accept_terms_of_service)
393 .add_message_handler(acknowledge_channel_message)
394 .add_message_handler(acknowledge_buffer_version)
395 .add_request_handler(get_supermaven_api_key)
396 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
397 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
398 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
399 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
400 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
401 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
402 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
403 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
404 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
405 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
406 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
407 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
408 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
409 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
410 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
411 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
412 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
415 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
416 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
417 .add_message_handler(update_context);
418
419 Arc::new(server)
420 }
421
    /// Schedules a one-time cleanup of resources left behind by stale servers.
    ///
    /// Returns `Ok(())` immediately. The work runs on a detached task that first
    /// sleeps for [`CLEANUP_TIMEOUT`]. It then works from the rooms and channels
    /// the database reports as stale relative to this server's id:
    /// - removes stale channel-chat participants;
    /// - clears stale channel-buffer collaborators and pushes the refreshed
    ///   collaborator list to the remaining connections;
    /// - for each stale room: clears stale participants, broadcasts the updated
    ///   room (and channel, if any), sends `CallCanceled` to users whose calls
    ///   were canceled, refreshes affected users' contacts, and deletes the
    ///   LiveKit room once it has no participants left;
    /// - finally deletes stale chat participants again, then old worktree
    ///   entries, then stale server records.
    ///
    /// Failures inside the task are passed to `trace_err` and skipped, not propagated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults apply if clearing the room fails: nothing to
                        // notify and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all
                        // of their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
592
593 pub fn teardown(&self) {
594 self.peer.teardown();
595 self.connection_pool.lock().reset();
596 let _ = self.teardown.send(true);
597 }
598
599 #[cfg(test)]
600 pub fn reset(&self, id: ServerId) {
601 self.teardown();
602 *self.id.lock() = id;
603 self.peer.reset(id.0 as u32);
604 let _ = self.teardown.send(false);
605 }
606
607 #[cfg(test)]
608 pub fn id(&self) -> ServerId {
609 *self.id.lock()
610 }
611
    /// Registers a type-erased handler for messages of type `M`.
    ///
    /// The stored closure downcasts the incoming envelope back to
    /// `TypedEnvelope<M>`, runs `handler`, and logs the outcome with timing:
    /// - `total_duration_ms`: time since the envelope's `received_at`;
    /// - `processing_duration_ms`: time since this closure was invoked, which
    ///   includes any time the returned future waits before being polled;
    /// - `queue_duration_ms`: the difference of the two.
    ///
    /// `payload_type` is only attached to the error log line.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Cannot fail: `handle_connection` looks handlers up by the
                // payload's `TypeId`, which is the key this closure is stored under.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
660
661 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
662 where
663 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
664 Fut: 'static + Send + Future<Output = Result<()>>,
665 M: EnvelopedMessage,
666 {
667 self.add_handler(move |envelope, session| handler(envelope.payload, session));
668 self
669 }
670
    /// Registers a handler for request/response messages of type `M`.
    ///
    /// `handler` receives the payload, a [`Response`] for replying, and the
    /// session. Once it finishes:
    /// - `Ok(())` after replying: success.
    /// - `Ok(())` without replying: an error is returned, which `add_handler`
    ///   logs. NOTE(review): nothing is sent to the client in this case.
    /// - `Err(error)`: the error is converted to a proto error, sent back to
    ///   the client, and then returned so it is logged.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared by every invocation of the type-erased closure below.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with `Response::send` so we can tell afterwards whether
                // the handler replied.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // `Error::Internal` is converted with `ErrorExt::to_proto`;
                        // any other variant is reported as `ErrorCode::Internal`
                        // carrying its Display text.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
709
    /// Builds the future that drives a single client connection.
    ///
    /// The returned future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers the connection with the peer and builds its [`Session`];
    /// 3. sends the initial client update (see `send_initial_client_update`);
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signaled;
    /// 5. calls `connection_lost` after a close or I/O failure.
    ///
    /// On teardown, or if steps 2–3 fail, it returns without calling
    /// `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Empty fields are filled in below and once the connection id is known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground + background) per connection; each
            // holds a permit until it finishes, so reading pauses once all are taken.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // The semaphore is local to this function and never closed,
                    // so acquiring a permit cannot fail.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then I/O, then progress on in-flight
                // foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only after the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops (cancels) any unfinished foreground handlers;
            // background handlers were spawned detached and keep running.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
857
858 async fn send_initial_client_update(
859 &self,
860 connection_id: ConnectionId,
861 zed_version: ZedVersion,
862 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
863 session: &Session,
864 ) -> Result<()> {
865 self.peer.send(
866 connection_id,
867 proto::Hello {
868 peer_id: Some(connection_id.into()),
869 },
870 )?;
871 tracing::info!("sent hello message");
872 if let Some(send_connection_id) = send_connection_id.take() {
873 let _ = send_connection_id.send(connection_id);
874 }
875
876 match &session.principal {
877 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
878 if !user.connected_once {
879 self.peer.send(connection_id, proto::ShowContacts {})?;
880 self.app_state
881 .db
882 .set_user_connected_once(user.id, true)
883 .await?;
884 }
885
886 update_user_plan(session).await?;
887
888 let contacts = self.app_state.db.get_contacts(user.id).await?;
889
890 {
891 let mut pool = self.connection_pool.lock();
892 pool.add_connection(connection_id, user.id, user.admin, zed_version);
893 self.peer.send(
894 connection_id,
895 build_initial_contacts_update(contacts, &pool),
896 )?;
897 }
898
899 if should_auto_subscribe_to_channels(zed_version) {
900 subscribe_user_to_channels(user.id, session).await?;
901 }
902
903 if let Some(incoming_call) =
904 self.app_state.db.incoming_call_for_user(user.id).await?
905 {
906 self.peer.send(connection_id, incoming_call)?;
907 }
908
909 update_user_contacts(user.id, session).await?;
910 }
911 }
912
913 Ok(())
914 }
915
916 pub async fn invite_code_redeemed(
917 self: &Arc<Self>,
918 inviter_id: UserId,
919 invitee_id: UserId,
920 ) -> Result<()> {
921 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
922 if let Some(code) = &user.invite_code {
923 let pool = self.connection_pool.lock();
924 let invitee_contact = contact_for_user(invitee_id, false, &pool);
925 for connection_id in pool.user_connection_ids(inviter_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateContacts {
929 contacts: vec![invitee_contact.clone()],
930 ..Default::default()
931 },
932 )?;
933 self.peer.send(
934 connection_id,
935 proto::UpdateInviteInfo {
936 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
937 count: user.invite_count as u32,
938 },
939 )?;
940 }
941 }
942 }
943 Ok(())
944 }
945
946 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
947 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
948 if let Some(invite_code) = &user.invite_code {
949 let pool = self.connection_pool.lock();
950 for connection_id in pool.user_connection_ids(user_id) {
951 self.peer.send(
952 connection_id,
953 proto::UpdateInviteInfo {
954 url: format!(
955 "{}{}",
956 self.app_state.config.invite_link_prefix, invite_code
957 ),
958 count: user.invite_count as u32,
959 },
960 )?;
961 }
962 }
963 }
964 Ok(())
965 }
966
967 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
968 let user = self
969 .app_state
970 .db
971 .get_user_by_id(user_id)
972 .await?
973 .context("user not found")?;
974
975 let update_user_plan = make_update_user_plan_message(
976 &user,
977 user.admin,
978 &self.app_state.db,
979 self.app_state.llm_db.clone(),
980 )
981 .await?;
982
983 let pool = self.connection_pool.lock();
984 for connection_id in pool.user_connection_ids(user_id) {
985 self.peer
986 .send(connection_id, update_user_plan.clone())
987 .trace_err();
988 }
989
990 Ok(())
991 }
992
993 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
994 let pool = self.connection_pool.lock();
995 for connection_id in pool.user_connection_ids(user_id) {
996 self.peer
997 .send(connection_id, proto::RefreshLlmToken {})
998 .trace_err();
999 }
1000 }
1001
1002 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1003 ServerSnapshot {
1004 connection_pool: ConnectionPoolGuard {
1005 guard: self.connection_pool.lock(),
1006 _not_send: PhantomData,
1007 },
1008 peer: &self.peer,
1009 }
1010 }
1011}
1012
1013impl Deref for ConnectionPoolGuard<'_> {
1014 type Target = ConnectionPool;
1015
1016 fn deref(&self) -> &Self::Target {
1017 &self.guard
1018 }
1019}
1020
1021impl DerefMut for ConnectionPoolGuard<'_> {
1022 fn deref_mut(&mut self) -> &mut Self::Target {
1023 &mut self.guard
1024 }
1025}
1026
/// Runs `check_invariants` when the guard is released, in test builds only.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call is compiled only under `cfg(test)`; outside tests this drop does
        // nothing beyond releasing the inner guard.
        #[cfg(test)]
        self.check_invariants();
    }
}
1033
1034fn broadcast<F>(
1035 sender_id: Option<ConnectionId>,
1036 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1037 mut f: F,
1038) where
1039 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1040{
1041 for receiver_id in receiver_ids {
1042 if Some(receiver_id) != sender_id {
1043 if let Err(error) = f(receiver_id) {
1044 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1045 }
1046 }
1047 }
1048}
1049
1050pub struct ProtocolVersion(u32);
1051
1052impl Header for ProtocolVersion {
1053 fn name() -> &'static HeaderName {
1054 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1055 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1056 }
1057
1058 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1059 where
1060 Self: Sized,
1061 I: Iterator<Item = &'i axum::http::HeaderValue>,
1062 {
1063 let version = values
1064 .next()
1065 .ok_or_else(axum::headers::Error::invalid)?
1066 .to_str()
1067 .map_err(|_| axum::headers::Error::invalid())?
1068 .parse()
1069 .map_err(|_| axum::headers::Error::invalid())?;
1070 Ok(Self(version))
1071 }
1072
1073 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1074 values.extend([self.0.to_string().parse().unwrap()]);
1075 }
1076}
1077
1078pub struct AppVersionHeader(SemanticVersion);
1079impl Header for AppVersionHeader {
1080 fn name() -> &'static HeaderName {
1081 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1082 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1083 }
1084
1085 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1086 where
1087 Self: Sized,
1088 I: Iterator<Item = &'i axum::http::HeaderValue>,
1089 {
1090 let version = values
1091 .next()
1092 .ok_or_else(axum::headers::Error::invalid)?
1093 .to_str()
1094 .map_err(|_| axum::headers::Error::invalid())?
1095 .parse()
1096 .map_err(|_| axum::headers::Error::invalid())?;
1097 Ok(Self(version))
1098 }
1099
1100 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1101 values.extend([self.0.to_string().parse().unwrap()]);
1102 }
1103}
1104
/// Builds the collab server's HTTP router.
///
/// - `/rpc` is handled by `handle_websocket_request`. The `ServiceBuilder` layer is
///   added immediately after this route. It provides the `AppState` extension and runs
///   `auth::validate_header`. Per axum's `Router::layer` semantics, a layer wraps only
///   the routes registered before it.
/// - `/metrics` is handled by `handle_metrics`. It is registered after that layer, so
///   it is not wrapped by it.
/// - The final `Extension(server)` layer is added after both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1116
1117pub async fn handle_websocket_request(
1118 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1119 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1120 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1121 Extension(server): Extension<Arc<Server>>,
1122 Extension(principal): Extension<Principal>,
1123 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1124 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1125 ws: WebSocketUpgrade,
1126) -> axum::response::Response {
1127 if protocol_version != rpc::PROTOCOL_VERSION {
1128 return (
1129 StatusCode::UPGRADE_REQUIRED,
1130 "client must be upgraded".to_string(),
1131 )
1132 .into_response();
1133 }
1134
1135 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1136 return (
1137 StatusCode::UPGRADE_REQUIRED,
1138 "no version header found".to_string(),
1139 )
1140 .into_response();
1141 };
1142
1143 if !version.can_collaborate() {
1144 return (
1145 StatusCode::UPGRADE_REQUIRED,
1146 "client must be upgraded".to_string(),
1147 )
1148 .into_response();
1149 }
1150
1151 let socket_address = socket_address.to_string();
1152 ws.on_upgrade(move |socket| {
1153 let socket = socket
1154 .map_ok(to_tungstenite_message)
1155 .err_into()
1156 .with(|message| async move { to_axum_message(message) });
1157 let connection = Connection::new(Box::pin(socket));
1158 async move {
1159 server
1160 .handle_connection(
1161 connection,
1162 socket_address,
1163 principal,
1164 version,
1165 country_code_header.map(|header| header.to_string()),
1166 system_id_header.map(|header| header.to_string()),
1167 None,
1168 Executor::Production,
1169 )
1170 .await;
1171 }
1172 })
1173}
1174
1175pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1176 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1177 let connections_metric = CONNECTIONS_METRIC
1178 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1179
1180 let connections = server
1181 .connection_pool
1182 .lock()
1183 .connections()
1184 .filter(|connection| !connection.admin)
1185 .count();
1186 connections_metric.set(connections as _);
1187
1188 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1189 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1190 register_int_gauge!(
1191 "shared_projects",
1192 "number of open projects with one or more guests"
1193 )
1194 .unwrap()
1195 });
1196
1197 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1198 shared_projects_metric.set(shared_projects as _);
1199
1200 let encoder = prometheus::TextEncoder::new();
1201 let metric_families = prometheus::gather();
1202 let encoded_metrics = encoder
1203 .encode_to_string(&metric_families)
1204 .map_err(|err| anyhow!("{err}"))?;
1205 Ok(encoded_metrics)
1206}
1207
1208#[instrument(err, skip(executor))]
1209async fn connection_lost(
1210 session: Session,
1211 mut teardown: watch::Receiver<bool>,
1212 executor: Executor,
1213) -> Result<()> {
1214 session.peer.disconnect(session.connection_id);
1215 session
1216 .connection_pool()
1217 .await
1218 .remove_connection(session.connection_id)?;
1219
1220 session
1221 .db()
1222 .await
1223 .connection_lost(session.connection_id)
1224 .await
1225 .trace_err();
1226
1227 futures::select_biased! {
1228 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1229
1230 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1231 leave_room_for_session(&session, session.connection_id).await.trace_err();
1232 leave_channel_buffers_for_session(&session)
1233 .await
1234 .trace_err();
1235
1236 if !session
1237 .connection_pool()
1238 .await
1239 .is_user_online(session.user_id())
1240 {
1241 let db = session.db().await;
1242 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1243 room_updated(&room, &session.peer);
1244 }
1245 }
1246
1247 update_user_contacts(session.user_id(), &session).await?;
1248 },
1249 _ = teardown.changed().fuse() => {}
1250 }
1251
1252 Ok(())
1253}
1254
1255/// Acknowledges a ping from a client, used to keep the connection alive.
1256async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1257 response.send(proto::Ack {})?;
1258 Ok(())
1259}
1260
1261/// Creates a new room for calling (outside of channels)
1262async fn create_room(
1263 _request: proto::CreateRoom,
1264 response: Response<proto::CreateRoom>,
1265 session: Session,
1266) -> Result<()> {
1267 let livekit_room = nanoid::nanoid!(30);
1268
1269 let live_kit_connection_info = util::maybe!(async {
1270 let live_kit = session.app_state.livekit_client.as_ref();
1271 let live_kit = live_kit?;
1272 let user_id = session.user_id().to_string();
1273
1274 let token = live_kit
1275 .room_token(&livekit_room, &user_id.to_string())
1276 .trace_err()?;
1277
1278 Some(proto::LiveKitConnectionInfo {
1279 server_url: live_kit.url().into(),
1280 token,
1281 can_publish: true,
1282 })
1283 })
1284 .await;
1285
1286 let room = session
1287 .db()
1288 .await
1289 .create_room(session.user_id(), session.connection_id, &livekit_room)
1290 .await?;
1291
1292 response.send(proto::CreateRoomResponse {
1293 room: Some(room.clone()),
1294 live_kit_connection_info,
1295 })?;
1296
1297 update_user_contacts(session.user_id(), &session).await?;
1298 Ok(())
1299}
1300
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, this delegates entirely to `join_channel_internal`.
///
/// Otherwise it:
/// 1. joins the room in the database and broadcasts the room update;
/// 2. sends `CallCanceled` for this room to every connection of the joining user;
/// 3. mints a LiveKit token when a LiveKit client is configured (token errors are
///    logged and result in no connection info);
/// 4. replies with the joined room;
/// 5. calls `update_user_contacts` for the user.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // The database result is broadcast and then unwrapped with `into_inner` inside this
    // block, before the connection pool is locked below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Tell all of this user's connections that the call for this room is over;
    // presumably this stops other devices from continuing to ring. TODO confirm.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Inside the inner scope, which holds the database result until `into_inner`:
/// 1. rejoins in the database and replies with the room plus the reshared and rejoined
///    projects;
/// 2. broadcasts the room update;
/// 3. for each reshared project:
///    - sends `UpdateProjectCollaborator` (old connection id -> this connection) to all
///      of its collaborators;
///    - forwards an `UpdateProject` with its worktrees to every collaborator except
///      this connection;
/// 4. streams rejoined-project state via `notify_rejoined_projects`.
///
/// After the scope ends, `channel_updated` is called if the room belongs to a channel.
/// Finally, `update_user_contacts` is called for this user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Replace this user's old peer id with the new one for every collaborator
            // (send failures are only logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
1460fn notify_rejoined_projects(
1461 rejoined_projects: &mut Vec<RejoinedProject>,
1462 session: &Session,
1463) -> Result<()> {
1464 for project in rejoined_projects.iter() {
1465 for collaborator in &project.collaborators {
1466 session
1467 .peer
1468 .send(
1469 collaborator.connection_id,
1470 proto::UpdateProjectCollaborator {
1471 project_id: project.id.to_proto(),
1472 old_peer_id: Some(project.old_connection_id.into()),
1473 new_peer_id: Some(session.connection_id.into()),
1474 },
1475 )
1476 .trace_err();
1477 }
1478 }
1479
1480 for project in rejoined_projects {
1481 for worktree in mem::take(&mut project.worktrees) {
1482 // Stream this worktree's entries.
1483 let message = proto::UpdateWorktree {
1484 project_id: project.id.to_proto(),
1485 worktree_id: worktree.id,
1486 abs_path: worktree.abs_path.clone(),
1487 root_name: worktree.root_name,
1488 updated_entries: worktree.updated_entries,
1489 removed_entries: worktree.removed_entries,
1490 scan_id: worktree.scan_id,
1491 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1492 updated_repositories: worktree.updated_repositories,
1493 removed_repositories: worktree.removed_repositories,
1494 };
1495 for update in proto::split_worktree_update(message) {
1496 session.peer.send(session.connection_id, update)?;
1497 }
1498
1499 // Stream this worktree's diagnostics.
1500 for summary in worktree.diagnostic_summaries {
1501 session.peer.send(
1502 session.connection_id,
1503 proto::UpdateDiagnosticSummary {
1504 project_id: project.id.to_proto(),
1505 worktree_id: worktree.id,
1506 summary: Some(summary),
1507 },
1508 )?;
1509 }
1510
1511 for settings_file in worktree.settings_files {
1512 session.peer.send(
1513 session.connection_id,
1514 proto::UpdateWorktreeSettings {
1515 project_id: project.id.to_proto(),
1516 worktree_id: worktree.id,
1517 path: settings_file.path,
1518 content: Some(settings_file.content),
1519 kind: Some(settings_file.kind.to_proto().into()),
1520 },
1521 )?;
1522 }
1523 }
1524
1525 for repository in mem::take(&mut project.updated_repositories) {
1526 for update in split_repository_update(repository) {
1527 session.peer.send(session.connection_id, update)?;
1528 }
1529 }
1530
1531 for id in mem::take(&mut project.removed_repositories) {
1532 session.peer.send(
1533 session.connection_id,
1534 proto::RemoveRepository {
1535 project_id: project.id.to_proto(),
1536 id,
1537 },
1538 )?;
1539 }
1540 }
1541
1542 Ok(())
1543}
1544
1545/// leave room disconnects from the room.
1546async fn leave_room(
1547 _: proto::LeaveRoom,
1548 response: Response<proto::LeaveRoom>,
1549 session: Session,
1550) -> Result<()> {
1551 leave_room_for_session(&session, session.connection_id).await?;
1552 response.send(proto::Ack {})?;
1553 Ok(())
1554}
1555
1556/// Updates the permissions of someone else in the room.
1557async fn set_room_participant_role(
1558 request: proto::SetRoomParticipantRole,
1559 response: Response<proto::SetRoomParticipantRole>,
1560 session: Session,
1561) -> Result<()> {
1562 let user_id = UserId::from_proto(request.user_id);
1563 let role = ChannelRole::from(request.role());
1564
1565 let (livekit_room, can_publish) = {
1566 let room = session
1567 .db()
1568 .await
1569 .set_room_participant_role(
1570 session.user_id(),
1571 RoomId::from_proto(request.room_id),
1572 user_id,
1573 role,
1574 )
1575 .await?;
1576
1577 let livekit_room = room.livekit_room.clone();
1578 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1579 room_updated(&room, &session.peer);
1580 (livekit_room, can_publish)
1581 };
1582
1583 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1584 live_kit
1585 .update_participant(
1586 livekit_room.clone(),
1587 request.user_id.to_string(),
1588 livekit_api::proto::ParticipantPermission {
1589 can_subscribe: true,
1590 can_publish,
1591 can_publish_data: can_publish,
1592 hidden: false,
1593 recorder: false,
1594 },
1595 )
1596 .await
1597 .trace_err();
1598 }
1599
1600 response.send(proto::Ack {})?;
1601 Ok(())
1602}
1603
/// Call someone else into the current room
///
/// Requires the callee to be a contact of the caller.
///
/// 1. Records the call in the database, broadcasts the room update and calls
///    `update_user_contacts` for the callee.
/// 2. Sends the incoming-call request to every connection of the callee concurrently.
/// 3. The first connection that answers successfully causes an `Ack` reply. Responses
///    still in flight are dropped.
///
/// If no connection succeeds (including when the callee has none), the call is marked
/// as failed in the database. The room update is broadcast and contacts are updated
/// again, and "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the request out of the database result so it outlives this scope.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once; results arrive in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1672
/// Cancel an outgoing call.
///
/// 1. Cancels the call from this connection to `called_user_id` in the database and
///    broadcasts the room update.
/// 2. Sends `CallCanceled` to every connection of the called user (send failures are
///    only logged).
/// 3. Acknowledges the request.
/// 4. Calls `update_user_contacts` for the called user.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1710
/// Decline an incoming call.
///
/// 1. Declines the call for this room in the database; fails with "declining call" if
///    the database returns no room.
/// 2. Broadcasts the room update.
/// 3. Sends `CallCanceled` to all of this user's connections.
/// 4. Calls `update_user_contacts` for the user.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is dropped before the connection pool is locked.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1742
/// Updates other participants in the room with your current location.
///
/// Fails with "invalid location" if the request carries no location. Otherwise it
/// stores the location in the database, broadcasts the room update and acknowledges.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request.location.context("invalid location")?;

    // `db` stays alive until the function returns, i.e. across the broadcast and the
    // acknowledgement below.
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1761
/// Share a project into the room.
///
/// 1. Registers the project in the database for this connection, with its worktrees
///    and `is_ssh_project` flag.
/// 2. Replies with the new project id.
/// 3. Broadcasts the room update.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1785
1786/// Unshare a project from the room.
1787async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1788 let project_id = ProjectId::from_proto(message.project_id);
1789 unshare_project_internal(project_id, session.connection_id, &session).await
1790}
1791
/// Unshares `project_id` on behalf of `connection_id`.
///
/// 1. Sends `UnshareProject` to every guest connection except `connection_id`.
/// 2. Broadcasts the room update when the project was in a room.
/// 3. Deletes the project from the database when the database result's `delete` flag
///    is set.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The block ends (dropping `room_guard`) before the database handle is acquired
    // again for the deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1829
1830/// Join someone elses shared project.
1831async fn join_project(
1832 request: proto::JoinProject,
1833 response: Response<proto::JoinProject>,
1834 session: Session,
1835) -> Result<()> {
1836 let project_id = ProjectId::from_proto(request.project_id);
1837
1838 tracing::info!(%project_id, "join project");
1839
1840 let db = session.db().await;
1841 let (project, replica_id) = &mut *db
1842 .join_project(
1843 project_id,
1844 session.connection_id,
1845 session.user_id(),
1846 request.committer_name.clone(),
1847 request.committer_email.clone(),
1848 )
1849 .await?;
1850 drop(db);
1851 tracing::info!(%project_id, "join remote project");
1852 let collaborators = project
1853 .collaborators
1854 .iter()
1855 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1856 .map(|collaborator| collaborator.to_proto())
1857 .collect::<Vec<_>>();
1858 let project_id = project.id;
1859 let guest_user_id = session.user_id();
1860
1861 let worktrees = project
1862 .worktrees
1863 .iter()
1864 .map(|(id, worktree)| proto::WorktreeMetadata {
1865 id: *id,
1866 root_name: worktree.root_name.clone(),
1867 visible: worktree.visible,
1868 abs_path: worktree.abs_path.clone(),
1869 })
1870 .collect::<Vec<_>>();
1871
1872 let add_project_collaborator = proto::AddProjectCollaborator {
1873 project_id: project_id.to_proto(),
1874 collaborator: Some(proto::Collaborator {
1875 peer_id: Some(session.connection_id.into()),
1876 replica_id: replica_id.0 as u32,
1877 user_id: guest_user_id.to_proto(),
1878 is_host: false,
1879 committer_name: request.committer_name.clone(),
1880 committer_email: request.committer_email.clone(),
1881 }),
1882 };
1883
1884 for collaborator in &collaborators {
1885 session
1886 .peer
1887 .send(
1888 collaborator.peer_id.unwrap().into(),
1889 add_project_collaborator.clone(),
1890 )
1891 .trace_err();
1892 }
1893
1894 // First, we send the metadata associated with each worktree.
1895 response.send(proto::JoinProjectResponse {
1896 project_id: project.id.0 as u64,
1897 worktrees: worktrees.clone(),
1898 replica_id: replica_id.0 as u32,
1899 collaborators: collaborators.clone(),
1900 language_servers: project.language_servers.clone(),
1901 role: project.role.into(),
1902 })?;
1903
1904 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1905 // Stream this worktree's entries.
1906 let message = proto::UpdateWorktree {
1907 project_id: project_id.to_proto(),
1908 worktree_id,
1909 abs_path: worktree.abs_path.clone(),
1910 root_name: worktree.root_name,
1911 updated_entries: worktree.entries,
1912 removed_entries: Default::default(),
1913 scan_id: worktree.scan_id,
1914 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1915 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1916 removed_repositories: Default::default(),
1917 };
1918 for update in proto::split_worktree_update(message) {
1919 session.peer.send(session.connection_id, update.clone())?;
1920 }
1921
1922 // Stream this worktree's diagnostics.
1923 for summary in worktree.diagnostic_summaries {
1924 session.peer.send(
1925 session.connection_id,
1926 proto::UpdateDiagnosticSummary {
1927 project_id: project_id.to_proto(),
1928 worktree_id: worktree.id,
1929 summary: Some(summary),
1930 },
1931 )?;
1932 }
1933
1934 for settings_file in worktree.settings_files {
1935 session.peer.send(
1936 session.connection_id,
1937 proto::UpdateWorktreeSettings {
1938 project_id: project_id.to_proto(),
1939 worktree_id: worktree.id,
1940 path: settings_file.path,
1941 content: Some(settings_file.content),
1942 kind: Some(settings_file.kind.to_proto() as i32),
1943 },
1944 )?;
1945 }
1946 }
1947
1948 for repository in mem::take(&mut project.repositories) {
1949 for update in split_repository_update(repository) {
1950 session.peer.send(session.connection_id, update)?;
1951 }
1952 }
1953
1954 for language_server in &project.language_servers {
1955 session.peer.send(
1956 session.connection_id,
1957 proto::UpdateLanguageServer {
1958 project_id: project_id.to_proto(),
1959 language_server_id: language_server.id,
1960 variant: Some(
1961 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1962 proto::LspDiskBasedDiagnosticsUpdated {},
1963 ),
1964 ),
1965 },
1966 )?;
1967 }
1968
1969 Ok(())
1970}
1971
/// Leave someone else's shared project.
///
/// Removes this connection from the project in the database, then calls
/// `project_left`. If the project is in a room, the room update is broadcast.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    // `db` stays alive until the function returns, i.e. across the notifications below.
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1991
1992/// Updates other participants with changes to the project
1993async fn update_project(
1994 request: proto::UpdateProject,
1995 response: Response<proto::UpdateProject>,
1996 session: Session,
1997) -> Result<()> {
1998 let project_id = ProjectId::from_proto(request.project_id);
1999 let (room, guest_connection_ids) = &*session
2000 .db()
2001 .await
2002 .update_project(project_id, session.connection_id, &request.worktrees)
2003 .await?;
2004 broadcast(
2005 Some(session.connection_id),
2006 guest_connection_ids.iter().copied(),
2007 |connection_id| {
2008 session
2009 .peer
2010 .forward_send(session.connection_id, connection_id, request.clone())
2011 },
2012 );
2013 if let Some(room) = room {
2014 room_updated(room, &session.peer);
2015 }
2016 response.send(proto::Ack {})?;
2017
2018 Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree
2022async fn update_worktree(
2023 request: proto::UpdateWorktree,
2024 response: Response<proto::UpdateWorktree>,
2025 session: Session,
2026) -> Result<()> {
2027 let guest_connection_ids = session
2028 .db()
2029 .await
2030 .update_worktree(&request, session.connection_id)
2031 .await?;
2032
2033 broadcast(
2034 Some(session.connection_id),
2035 guest_connection_ids.iter().copied(),
2036 |connection_id| {
2037 session
2038 .peer
2039 .forward_send(session.connection_id, connection_id, request.clone())
2040 },
2041 );
2042 response.send(proto::Ack {})?;
2043 Ok(())
2044}
2045
2046async fn update_repository(
2047 request: proto::UpdateRepository,
2048 response: Response<proto::UpdateRepository>,
2049 session: Session,
2050) -> Result<()> {
2051 let guest_connection_ids = session
2052 .db()
2053 .await
2054 .update_repository(&request, session.connection_id)
2055 .await?;
2056
2057 broadcast(
2058 Some(session.connection_id),
2059 guest_connection_ids.iter().copied(),
2060 |connection_id| {
2061 session
2062 .peer
2063 .forward_send(session.connection_id, connection_id, request.clone())
2064 },
2065 );
2066 response.send(proto::Ack {})?;
2067 Ok(())
2068}
2069
2070async fn remove_repository(
2071 request: proto::RemoveRepository,
2072 response: Response<proto::RemoveRepository>,
2073 session: Session,
2074) -> Result<()> {
2075 let guest_connection_ids = session
2076 .db()
2077 .await
2078 .remove_repository(&request, session.connection_id)
2079 .await?;
2080
2081 broadcast(
2082 Some(session.connection_id),
2083 guest_connection_ids.iter().copied(),
2084 |connection_id| {
2085 session
2086 .peer
2087 .forward_send(session.connection_id, connection_id, request.clone())
2088 },
2089 );
2090 response.send(proto::Ack {})?;
2091 Ok(())
2092}
2093
2094/// Updates other participants with changes to the diagnostics
2095async fn update_diagnostic_summary(
2096 message: proto::UpdateDiagnosticSummary,
2097 session: Session,
2098) -> Result<()> {
2099 let guest_connection_ids = session
2100 .db()
2101 .await
2102 .update_diagnostic_summary(&message, session.connection_id)
2103 .await?;
2104
2105 broadcast(
2106 Some(session.connection_id),
2107 guest_connection_ids.iter().copied(),
2108 |connection_id| {
2109 session
2110 .peer
2111 .forward_send(session.connection_id, connection_id, message.clone())
2112 },
2113 );
2114
2115 Ok(())
2116}
2117
2118/// Updates other participants with changes to the worktree settings
2119async fn update_worktree_settings(
2120 message: proto::UpdateWorktreeSettings,
2121 session: Session,
2122) -> Result<()> {
2123 let guest_connection_ids = session
2124 .db()
2125 .await
2126 .update_worktree_settings(&message, session.connection_id)
2127 .await?;
2128
2129 broadcast(
2130 Some(session.connection_id),
2131 guest_connection_ids.iter().copied(),
2132 |connection_id| {
2133 session
2134 .peer
2135 .forward_send(session.connection_id, connection_id, message.clone())
2136 },
2137 );
2138
2139 Ok(())
2140}
2141
2142/// Notify other participants that a language server has started.
2143async fn start_language_server(
2144 request: proto::StartLanguageServer,
2145 session: Session,
2146) -> Result<()> {
2147 let guest_connection_ids = session
2148 .db()
2149 .await
2150 .start_language_server(&request, session.connection_id)
2151 .await?;
2152
2153 broadcast(
2154 Some(session.connection_id),
2155 guest_connection_ids.iter().copied(),
2156 |connection_id| {
2157 session
2158 .peer
2159 .forward_send(session.connection_id, connection_id, request.clone())
2160 },
2161 );
2162 Ok(())
2163}
2164
2165/// Notify other participants that a language server has changed.
2166async fn update_language_server(
2167 request: proto::UpdateLanguageServer,
2168 session: Session,
2169) -> Result<()> {
2170 let project_id = ProjectId::from_proto(request.project_id);
2171 let project_connection_ids = session
2172 .db()
2173 .await
2174 .project_connection_ids(project_id, session.connection_id, true)
2175 .await?;
2176 broadcast(
2177 Some(session.connection_id),
2178 project_connection_ids.iter().copied(),
2179 |connection_id| {
2180 session
2181 .peer
2182 .forward_send(session.connection_id, connection_id, request.clone())
2183 },
2184 );
2185 Ok(())
2186}
2187
2188/// forward a project request to the host. These requests should be read only
2189/// as guests are allowed to send them.
2190async fn forward_read_only_project_request<T>(
2191 request: T,
2192 response: Response<T>,
2193 session: Session,
2194) -> Result<()>
2195where
2196 T: EntityMessage + RequestMessage,
2197{
2198 let project_id = ProjectId::from_proto(request.remote_entity_id());
2199 let host_connection_id = session
2200 .db()
2201 .await
2202 .host_for_read_only_project_request(project_id, session.connection_id)
2203 .await?;
2204 let payload = session
2205 .peer
2206 .forward_request(session.connection_id, host_connection_id, request)
2207 .await?;
2208 response.send(payload)?;
2209 Ok(())
2210}
2211
2212async fn forward_find_search_candidates_request(
2213 request: proto::FindSearchCandidates,
2214 response: Response<proto::FindSearchCandidates>,
2215 session: Session,
2216) -> Result<()> {
2217 let project_id = ProjectId::from_proto(request.remote_entity_id());
2218 let host_connection_id = session
2219 .db()
2220 .await
2221 .host_for_read_only_project_request(project_id, session.connection_id)
2222 .await?;
2223 let payload = session
2224 .peer
2225 .forward_request(session.connection_id, host_connection_id, request)
2226 .await?;
2227 response.send(payload)?;
2228 Ok(())
2229}
2230
2231/// forward a project request to the host. These requests are disallowed
2232/// for guests.
2233async fn forward_mutating_project_request<T>(
2234 request: T,
2235 response: Response<T>,
2236 session: Session,
2237) -> Result<()>
2238where
2239 T: EntityMessage + RequestMessage,
2240{
2241 let project_id = ProjectId::from_proto(request.remote_entity_id());
2242
2243 let host_connection_id = session
2244 .db()
2245 .await
2246 .host_for_mutating_project_request(project_id, session.connection_id)
2247 .await?;
2248 let payload = session
2249 .peer
2250 .forward_request(session.connection_id, host_connection_id, request)
2251 .await?;
2252 response.send(payload)?;
2253 Ok(())
2254}
2255
2256/// Notify other participants that a new buffer has been created
2257async fn create_buffer_for_peer(
2258 request: proto::CreateBufferForPeer,
2259 session: Session,
2260) -> Result<()> {
2261 session
2262 .db()
2263 .await
2264 .check_user_is_project_host(
2265 ProjectId::from_proto(request.project_id),
2266 session.connection_id,
2267 )
2268 .await?;
2269 let peer_id = request.peer_id.context("invalid peer id")?;
2270 session
2271 .peer
2272 .forward_send(session.connection_id, peer_id.into(), request)?;
2273 Ok(())
2274}
2275
2276/// Notify other participants that a buffer has been updated. This is
2277/// allowed for guests as long as the update is limited to selections.
2278async fn update_buffer(
2279 request: proto::UpdateBuffer,
2280 response: Response<proto::UpdateBuffer>,
2281 session: Session,
2282) -> Result<()> {
2283 let project_id = ProjectId::from_proto(request.project_id);
2284 let mut capability = Capability::ReadOnly;
2285
2286 for op in request.operations.iter() {
2287 match op.variant {
2288 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2289 Some(_) => capability = Capability::ReadWrite,
2290 }
2291 }
2292
2293 let host = {
2294 let guard = session
2295 .db()
2296 .await
2297 .connections_for_buffer_update(project_id, session.connection_id, capability)
2298 .await?;
2299
2300 let (host, guests) = &*guard;
2301
2302 broadcast(
2303 Some(session.connection_id),
2304 guests.clone(),
2305 |connection_id| {
2306 session
2307 .peer
2308 .forward_send(session.connection_id, connection_id, request.clone())
2309 },
2310 );
2311
2312 *host
2313 };
2314
2315 if host != session.connection_id {
2316 session
2317 .peer
2318 .forward_request(session.connection_id, host, request.clone())
2319 .await?;
2320 }
2321
2322 response.send(proto::Ack {})?;
2323 Ok(())
2324}
2325
2326async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2327 let project_id = ProjectId::from_proto(message.project_id);
2328
2329 let operation = message.operation.as_ref().context("invalid operation")?;
2330 let capability = match operation.variant.as_ref() {
2331 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2332 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2333 match buffer_op.variant {
2334 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2335 Capability::ReadOnly
2336 }
2337 _ => Capability::ReadWrite,
2338 }
2339 } else {
2340 Capability::ReadWrite
2341 }
2342 }
2343 Some(_) => Capability::ReadWrite,
2344 None => Capability::ReadOnly,
2345 };
2346
2347 let guard = session
2348 .db()
2349 .await
2350 .connections_for_buffer_update(project_id, session.connection_id, capability)
2351 .await?;
2352
2353 let (host, guests) = &*guard;
2354
2355 broadcast(
2356 Some(session.connection_id),
2357 guests.iter().chain([host]).copied(),
2358 |connection_id| {
2359 session
2360 .peer
2361 .forward_send(session.connection_id, connection_id, message.clone())
2362 },
2363 );
2364
2365 Ok(())
2366}
2367
2368/// Notify other participants that a project has been updated.
2369async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2370 request: T,
2371 session: Session,
2372) -> Result<()> {
2373 let project_id = ProjectId::from_proto(request.remote_entity_id());
2374 let project_connection_ids = session
2375 .db()
2376 .await
2377 .project_connection_ids(project_id, session.connection_id, false)
2378 .await?;
2379
2380 broadcast(
2381 Some(session.connection_id),
2382 project_connection_ids.iter().copied(),
2383 |connection_id| {
2384 session
2385 .peer
2386 .forward_send(session.connection_id, connection_id, request.clone())
2387 },
2388 );
2389 Ok(())
2390}
2391
2392/// Start following another user in a call.
2393async fn follow(
2394 request: proto::Follow,
2395 response: Response<proto::Follow>,
2396 session: Session,
2397) -> Result<()> {
2398 let room_id = RoomId::from_proto(request.room_id);
2399 let project_id = request.project_id.map(ProjectId::from_proto);
2400 let leader_id = request.leader_id.context("invalid leader id")?.into();
2401 let follower_id = session.connection_id;
2402
2403 session
2404 .db()
2405 .await
2406 .check_room_participants(room_id, leader_id, session.connection_id)
2407 .await?;
2408
2409 let response_payload = session
2410 .peer
2411 .forward_request(session.connection_id, leader_id, request)
2412 .await?;
2413 response.send(response_payload)?;
2414
2415 if let Some(project_id) = project_id {
2416 let room = session
2417 .db()
2418 .await
2419 .follow(room_id, project_id, leader_id, follower_id)
2420 .await?;
2421 room_updated(&room, &session.peer);
2422 }
2423
2424 Ok(())
2425}
2426
2427/// Stop following another user in a call.
2428async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2429 let room_id = RoomId::from_proto(request.room_id);
2430 let project_id = request.project_id.map(ProjectId::from_proto);
2431 let leader_id = request.leader_id.context("invalid leader id")?.into();
2432 let follower_id = session.connection_id;
2433
2434 session
2435 .db()
2436 .await
2437 .check_room_participants(room_id, leader_id, session.connection_id)
2438 .await?;
2439
2440 session
2441 .peer
2442 .forward_send(session.connection_id, leader_id, request)?;
2443
2444 if let Some(project_id) = project_id {
2445 let room = session
2446 .db()
2447 .await
2448 .unfollow(room_id, project_id, leader_id, follower_id)
2449 .await?;
2450 room_updated(&room, &session.peer);
2451 }
2452
2453 Ok(())
2454}
2455
2456/// Notify everyone following you of your current location.
2457async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2458 let room_id = RoomId::from_proto(request.room_id);
2459 let database = session.db.lock().await;
2460
2461 let connection_ids = if let Some(project_id) = request.project_id {
2462 let project_id = ProjectId::from_proto(project_id);
2463 database
2464 .project_connection_ids(project_id, session.connection_id, true)
2465 .await?
2466 } else {
2467 database
2468 .room_connection_ids(room_id, session.connection_id)
2469 .await?
2470 };
2471
2472 // For now, don't send view update messages back to that view's current leader.
2473 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2474 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2475 _ => None,
2476 });
2477
2478 for connection_id in connection_ids.iter().cloned() {
2479 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2480 session
2481 .peer
2482 .forward_send(session.connection_id, connection_id, request.clone())?;
2483 }
2484 }
2485 Ok(())
2486}
2487
2488/// Get public data about users.
2489async fn get_users(
2490 request: proto::GetUsers,
2491 response: Response<proto::GetUsers>,
2492 session: Session,
2493) -> Result<()> {
2494 let user_ids = request
2495 .user_ids
2496 .into_iter()
2497 .map(UserId::from_proto)
2498 .collect();
2499 let users = session
2500 .db()
2501 .await
2502 .get_users_by_ids(user_ids)
2503 .await?
2504 .into_iter()
2505 .map(|user| proto::User {
2506 id: user.id.to_proto(),
2507 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2508 github_login: user.github_login,
2509 name: user.name,
2510 })
2511 .collect();
2512 response.send(proto::UsersResponse { users })?;
2513 Ok(())
2514}
2515
2516/// Search for users (to invite) buy Github login
2517async fn fuzzy_search_users(
2518 request: proto::FuzzySearchUsers,
2519 response: Response<proto::FuzzySearchUsers>,
2520 session: Session,
2521) -> Result<()> {
2522 let query = request.query;
2523 let users = match query.len() {
2524 0 => vec![],
2525 1 | 2 => session
2526 .db()
2527 .await
2528 .get_user_by_github_login(&query)
2529 .await?
2530 .into_iter()
2531 .collect(),
2532 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2533 };
2534 let users = users
2535 .into_iter()
2536 .filter(|user| user.id != session.user_id())
2537 .map(|user| proto::User {
2538 id: user.id.to_proto(),
2539 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2540 github_login: user.github_login,
2541 name: user.name,
2542 })
2543 .collect();
2544 response.send(proto::UsersResponse { users })?;
2545 Ok(())
2546}
2547
2548/// Send a contact request to another user.
2549async fn request_contact(
2550 request: proto::RequestContact,
2551 response: Response<proto::RequestContact>,
2552 session: Session,
2553) -> Result<()> {
2554 let requester_id = session.user_id();
2555 let responder_id = UserId::from_proto(request.responder_id);
2556 if requester_id == responder_id {
2557 return Err(anyhow!("cannot add yourself as a contact"))?;
2558 }
2559
2560 let notifications = session
2561 .db()
2562 .await
2563 .send_contact_request(requester_id, responder_id)
2564 .await?;
2565
2566 // Update outgoing contact requests of requester
2567 let mut update = proto::UpdateContacts::default();
2568 update.outgoing_requests.push(responder_id.to_proto());
2569 for connection_id in session
2570 .connection_pool()
2571 .await
2572 .user_connection_ids(requester_id)
2573 {
2574 session.peer.send(connection_id, update.clone())?;
2575 }
2576
2577 // Update incoming contact requests of responder
2578 let mut update = proto::UpdateContacts::default();
2579 update
2580 .incoming_requests
2581 .push(proto::IncomingContactRequest {
2582 requester_id: requester_id.to_proto(),
2583 });
2584 let connection_pool = session.connection_pool().await;
2585 for connection_id in connection_pool.user_connection_ids(responder_id) {
2586 session.peer.send(connection_id, update.clone())?;
2587 }
2588
2589 send_notifications(&connection_pool, &session.peer, notifications);
2590
2591 response.send(proto::Ack {})?;
2592 Ok(())
2593}
2594
2595/// Accept or decline a contact request
2596async fn respond_to_contact_request(
2597 request: proto::RespondToContactRequest,
2598 response: Response<proto::RespondToContactRequest>,
2599 session: Session,
2600) -> Result<()> {
2601 let responder_id = session.user_id();
2602 let requester_id = UserId::from_proto(request.requester_id);
2603 let db = session.db().await;
2604 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2605 db.dismiss_contact_notification(responder_id, requester_id)
2606 .await?;
2607 } else {
2608 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2609
2610 let notifications = db
2611 .respond_to_contact_request(responder_id, requester_id, accept)
2612 .await?;
2613 let requester_busy = db.is_user_busy(requester_id).await?;
2614 let responder_busy = db.is_user_busy(responder_id).await?;
2615
2616 let pool = session.connection_pool().await;
2617 // Update responder with new contact
2618 let mut update = proto::UpdateContacts::default();
2619 if accept {
2620 update
2621 .contacts
2622 .push(contact_for_user(requester_id, requester_busy, &pool));
2623 }
2624 update
2625 .remove_incoming_requests
2626 .push(requester_id.to_proto());
2627 for connection_id in pool.user_connection_ids(responder_id) {
2628 session.peer.send(connection_id, update.clone())?;
2629 }
2630
2631 // Update requester with new contact
2632 let mut update = proto::UpdateContacts::default();
2633 if accept {
2634 update
2635 .contacts
2636 .push(contact_for_user(responder_id, responder_busy, &pool));
2637 }
2638 update
2639 .remove_outgoing_requests
2640 .push(responder_id.to_proto());
2641
2642 for connection_id in pool.user_connection_ids(requester_id) {
2643 session.peer.send(connection_id, update.clone())?;
2644 }
2645
2646 send_notifications(&pool, &session.peer, notifications);
2647 }
2648
2649 response.send(proto::Ack {})?;
2650 Ok(())
2651}
2652
2653/// Remove a contact.
2654async fn remove_contact(
2655 request: proto::RemoveContact,
2656 response: Response<proto::RemoveContact>,
2657 session: Session,
2658) -> Result<()> {
2659 let requester_id = session.user_id();
2660 let responder_id = UserId::from_proto(request.user_id);
2661 let db = session.db().await;
2662 let (contact_accepted, deleted_notification_id) =
2663 db.remove_contact(requester_id, responder_id).await?;
2664
2665 let pool = session.connection_pool().await;
2666 // Update outgoing contact requests of requester
2667 let mut update = proto::UpdateContacts::default();
2668 if contact_accepted {
2669 update.remove_contacts.push(responder_id.to_proto());
2670 } else {
2671 update
2672 .remove_outgoing_requests
2673 .push(responder_id.to_proto());
2674 }
2675 for connection_id in pool.user_connection_ids(requester_id) {
2676 session.peer.send(connection_id, update.clone())?;
2677 }
2678
2679 // Update incoming contact requests of responder
2680 let mut update = proto::UpdateContacts::default();
2681 if contact_accepted {
2682 update.remove_contacts.push(requester_id.to_proto());
2683 } else {
2684 update
2685 .remove_incoming_requests
2686 .push(requester_id.to_proto());
2687 }
2688 for connection_id in pool.user_connection_ids(responder_id) {
2689 session.peer.send(connection_id, update.clone())?;
2690 if let Some(notification_id) = deleted_notification_id {
2691 session.peer.send(
2692 connection_id,
2693 proto::DeleteNotification {
2694 notification_id: notification_id.to_proto(),
2695 },
2696 )?;
2697 }
2698 }
2699
2700 response.send(proto::Ack {})?;
2701 Ok(())
2702}
2703
2704fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2705 version.0.minor() < 139
2706}
2707
2708async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2709 if is_staff {
2710 return Ok(proto::Plan::ZedPro);
2711 }
2712
2713 let subscription = db.get_active_billing_subscription(user_id).await?;
2714 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2715
2716 let plan = if let Some(subscription_kind) = subscription_kind {
2717 match subscription_kind {
2718 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2719 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2720 SubscriptionKind::ZedFree => proto::Plan::Free,
2721 }
2722 } else {
2723 proto::Plan::Free
2724 };
2725
2726 Ok(plan)
2727}
2728
/// Builds the `UpdateUserPlan` message describing `user`'s plan, trial,
/// billing state and (when available) current-period usage.
///
/// * `is_staff` forces the plan to `ZedPro` (via `current_plan`) and reports
///   usage-based billing as enabled.
/// * `llm_db`: when `None`, no subscription period or usage is included.
///
/// Any database error is returned.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Period and usage are only computed when an LLM database is available.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already fetched the active billing
        // subscription; this is a second query for the same data.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only looked up when there is a current period to scope it to.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Any plan other than `ZedPro` (including `ZedProTrial`) is subject to the
    // minimum account age check.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // NOTE(review): `as u64` would wrap a negative (pre-1970) timestamp;
        // presumably that cannot occur here. Confirm.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff are always reported as having usage-based billing enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Map the protocol plan onto the LLM client's plan to look up limits.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Trial users with the extended-trial feature flag get a model
            // request limit of 1_000 instead of the plan's default limit.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below silently truncate values
            // that do not fit in u32.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                // The edit-prediction limit comes straight from the plan; the
                // extended-trial flag does not affect it.
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2837
2838async fn update_user_plan(session: &Session) -> Result<()> {
2839 let db = session.db().await;
2840
2841 let update_user_plan = make_update_user_plan_message(
2842 session.principal.user(),
2843 session.is_staff(),
2844 &db.0,
2845 session.app_state.llm_db.clone(),
2846 )
2847 .await?;
2848
2849 session
2850 .peer
2851 .send(session.connection_id, update_user_plan)
2852 .trace_err();
2853
2854 Ok(())
2855}
2856
2857async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2858 subscribe_user_to_channels(session.user_id(), &session).await?;
2859 Ok(())
2860}
2861
2862async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2863 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2864 let mut pool = session.connection_pool().await;
2865 for membership in &channels_for_user.channel_memberships {
2866 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2867 }
2868 session.peer.send(
2869 session.connection_id,
2870 build_update_user_channels(&channels_for_user),
2871 )?;
2872 session.peer.send(
2873 session.connection_id,
2874 build_channels_update(channels_for_user),
2875 )?;
2876 Ok(())
2877}
2878
2879/// Creates a new channel.
2880async fn create_channel(
2881 request: proto::CreateChannel,
2882 response: Response<proto::CreateChannel>,
2883 session: Session,
2884) -> Result<()> {
2885 let db = session.db().await;
2886
2887 let parent_id = request.parent_id.map(ChannelId::from_proto);
2888 let (channel, membership) = db
2889 .create_channel(&request.name, parent_id, session.user_id())
2890 .await?;
2891
2892 let root_id = channel.root_id();
2893 let channel = Channel::from_model(channel);
2894
2895 response.send(proto::CreateChannelResponse {
2896 channel: Some(channel.to_proto()),
2897 parent_id: request.parent_id,
2898 })?;
2899
2900 let mut connection_pool = session.connection_pool().await;
2901 if let Some(membership) = membership {
2902 connection_pool.subscribe_to_channel(
2903 membership.user_id,
2904 membership.channel_id,
2905 membership.role,
2906 );
2907 let update = proto::UpdateUserChannels {
2908 channel_memberships: vec![proto::ChannelMembership {
2909 channel_id: membership.channel_id.to_proto(),
2910 role: membership.role.into(),
2911 }],
2912 ..Default::default()
2913 };
2914 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2915 session.peer.send(connection_id, update.clone())?;
2916 }
2917 }
2918
2919 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2920 if !role.can_see_channel(channel.visibility) {
2921 continue;
2922 }
2923
2924 let update = proto::UpdateChannels {
2925 channels: vec![channel.to_proto()],
2926 ..Default::default()
2927 };
2928 session.peer.send(connection_id, update.clone())?;
2929 }
2930
2931 Ok(())
2932}
2933
2934/// Delete a channel
2935async fn delete_channel(
2936 request: proto::DeleteChannel,
2937 response: Response<proto::DeleteChannel>,
2938 session: Session,
2939) -> Result<()> {
2940 let db = session.db().await;
2941
2942 let channel_id = request.channel_id;
2943 let (root_channel, removed_channels) = db
2944 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2945 .await?;
2946 response.send(proto::Ack {})?;
2947
2948 // Notify members of removed channels
2949 let mut update = proto::UpdateChannels::default();
2950 update
2951 .delete_channels
2952 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2953
2954 let connection_pool = session.connection_pool().await;
2955 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2956 session.peer.send(connection_id, update.clone())?;
2957 }
2958
2959 Ok(())
2960}
2961
2962/// Invite someone to join a channel.
2963async fn invite_channel_member(
2964 request: proto::InviteChannelMember,
2965 response: Response<proto::InviteChannelMember>,
2966 session: Session,
2967) -> Result<()> {
2968 let db = session.db().await;
2969 let channel_id = ChannelId::from_proto(request.channel_id);
2970 let invitee_id = UserId::from_proto(request.user_id);
2971 let InviteMemberResult {
2972 channel,
2973 notifications,
2974 } = db
2975 .invite_channel_member(
2976 channel_id,
2977 invitee_id,
2978 session.user_id(),
2979 request.role().into(),
2980 )
2981 .await?;
2982
2983 let update = proto::UpdateChannels {
2984 channel_invitations: vec![channel.to_proto()],
2985 ..Default::default()
2986 };
2987
2988 let connection_pool = session.connection_pool().await;
2989 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2990 session.peer.send(connection_id, update.clone())?;
2991 }
2992
2993 send_notifications(&connection_pool, &session.peer, notifications);
2994
2995 response.send(proto::Ack {})?;
2996 Ok(())
2997}
2998
2999/// remove someone from a channel
3000async fn remove_channel_member(
3001 request: proto::RemoveChannelMember,
3002 response: Response<proto::RemoveChannelMember>,
3003 session: Session,
3004) -> Result<()> {
3005 let db = session.db().await;
3006 let channel_id = ChannelId::from_proto(request.channel_id);
3007 let member_id = UserId::from_proto(request.user_id);
3008
3009 let RemoveChannelMemberResult {
3010 membership_update,
3011 notification_id,
3012 } = db
3013 .remove_channel_member(channel_id, member_id, session.user_id())
3014 .await?;
3015
3016 let mut connection_pool = session.connection_pool().await;
3017 notify_membership_updated(
3018 &mut connection_pool,
3019 membership_update,
3020 member_id,
3021 &session.peer,
3022 );
3023 for connection_id in connection_pool.user_connection_ids(member_id) {
3024 if let Some(notification_id) = notification_id {
3025 session
3026 .peer
3027 .send(
3028 connection_id,
3029 proto::DeleteNotification {
3030 notification_id: notification_id.to_proto(),
3031 },
3032 )
3033 .trace_err();
3034 }
3035 }
3036
3037 response.send(proto::Ack {})?;
3038 Ok(())
3039}
3040
3041/// Toggle the channel between public and private.
3042/// Care is taken to maintain the invariant that public channels only descend from public channels,
3043/// (though members-only channels can appear at any point in the hierarchy).
3044async fn set_channel_visibility(
3045 request: proto::SetChannelVisibility,
3046 response: Response<proto::SetChannelVisibility>,
3047 session: Session,
3048) -> Result<()> {
3049 let db = session.db().await;
3050 let channel_id = ChannelId::from_proto(request.channel_id);
3051 let visibility = request.visibility().into();
3052
3053 let channel_model = db
3054 .set_channel_visibility(channel_id, visibility, session.user_id())
3055 .await?;
3056 let root_id = channel_model.root_id();
3057 let channel = Channel::from_model(channel_model);
3058
3059 let mut connection_pool = session.connection_pool().await;
3060 for (user_id, role) in connection_pool
3061 .channel_user_ids(root_id)
3062 .collect::<Vec<_>>()
3063 .into_iter()
3064 {
3065 let update = if role.can_see_channel(channel.visibility) {
3066 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3067 proto::UpdateChannels {
3068 channels: vec![channel.to_proto()],
3069 ..Default::default()
3070 }
3071 } else {
3072 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3073 proto::UpdateChannels {
3074 delete_channels: vec![channel.id.to_proto()],
3075 ..Default::default()
3076 }
3077 };
3078
3079 for connection_id in connection_pool.user_connection_ids(user_id) {
3080 session.peer.send(connection_id, update.clone())?;
3081 }
3082 }
3083
3084 response.send(proto::Ack {})?;
3085 Ok(())
3086}
3087
3088/// Alter the role for a user in the channel.
3089async fn set_channel_member_role(
3090 request: proto::SetChannelMemberRole,
3091 response: Response<proto::SetChannelMemberRole>,
3092 session: Session,
3093) -> Result<()> {
3094 let db = session.db().await;
3095 let channel_id = ChannelId::from_proto(request.channel_id);
3096 let member_id = UserId::from_proto(request.user_id);
3097 let result = db
3098 .set_channel_member_role(
3099 channel_id,
3100 session.user_id(),
3101 member_id,
3102 request.role().into(),
3103 )
3104 .await?;
3105
3106 match result {
3107 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3108 let mut connection_pool = session.connection_pool().await;
3109 notify_membership_updated(
3110 &mut connection_pool,
3111 membership_update,
3112 member_id,
3113 &session.peer,
3114 )
3115 }
3116 db::SetMemberRoleResult::InviteUpdated(channel) => {
3117 let update = proto::UpdateChannels {
3118 channel_invitations: vec![channel.to_proto()],
3119 ..Default::default()
3120 };
3121
3122 for connection_id in session
3123 .connection_pool()
3124 .await
3125 .user_connection_ids(member_id)
3126 {
3127 session.peer.send(connection_id, update.clone())?;
3128 }
3129 }
3130 }
3131
3132 response.send(proto::Ack {})?;
3133 Ok(())
3134}
3135
3136/// Change the name of a channel
3137async fn rename_channel(
3138 request: proto::RenameChannel,
3139 response: Response<proto::RenameChannel>,
3140 session: Session,
3141) -> Result<()> {
3142 let db = session.db().await;
3143 let channel_id = ChannelId::from_proto(request.channel_id);
3144 let channel_model = db
3145 .rename_channel(channel_id, session.user_id(), &request.name)
3146 .await?;
3147 let root_id = channel_model.root_id();
3148 let channel = Channel::from_model(channel_model);
3149
3150 response.send(proto::RenameChannelResponse {
3151 channel: Some(channel.to_proto()),
3152 })?;
3153
3154 let connection_pool = session.connection_pool().await;
3155 let update = proto::UpdateChannels {
3156 channels: vec![channel.to_proto()],
3157 ..Default::default()
3158 };
3159 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3160 if role.can_see_channel(channel.visibility) {
3161 session.peer.send(connection_id, update.clone())?;
3162 }
3163 }
3164
3165 Ok(())
3166}
3167
3168/// Move a channel to a new parent.
3169async fn move_channel(
3170 request: proto::MoveChannel,
3171 response: Response<proto::MoveChannel>,
3172 session: Session,
3173) -> Result<()> {
3174 let channel_id = ChannelId::from_proto(request.channel_id);
3175 let to = ChannelId::from_proto(request.to);
3176
3177 let (root_id, channels) = session
3178 .db()
3179 .await
3180 .move_channel(channel_id, to, session.user_id())
3181 .await?;
3182
3183 let connection_pool = session.connection_pool().await;
3184 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3185 let channels = channels
3186 .iter()
3187 .filter_map(|channel| {
3188 if role.can_see_channel(channel.visibility) {
3189 Some(channel.to_proto())
3190 } else {
3191 None
3192 }
3193 })
3194 .collect::<Vec<_>>();
3195 if channels.is_empty() {
3196 continue;
3197 }
3198
3199 let update = proto::UpdateChannels {
3200 channels,
3201 ..Default::default()
3202 };
3203
3204 session.peer.send(connection_id, update.clone())?;
3205 }
3206
3207 response.send(Ack {})?;
3208 Ok(())
3209}
3210
3211/// Get the list of channel members
3212async fn get_channel_members(
3213 request: proto::GetChannelMembers,
3214 response: Response<proto::GetChannelMembers>,
3215 session: Session,
3216) -> Result<()> {
3217 let db = session.db().await;
3218 let channel_id = ChannelId::from_proto(request.channel_id);
3219 let limit = if request.limit == 0 {
3220 u16::MAX as u64
3221 } else {
3222 request.limit
3223 };
3224 let (members, users) = db
3225 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3226 .await?;
3227 response.send(proto::GetChannelMembersResponse { members, users })?;
3228 Ok(())
3229}
3230
3231/// Accept or decline a channel invitation.
3232async fn respond_to_channel_invite(
3233 request: proto::RespondToChannelInvite,
3234 response: Response<proto::RespondToChannelInvite>,
3235 session: Session,
3236) -> Result<()> {
3237 let db = session.db().await;
3238 let channel_id = ChannelId::from_proto(request.channel_id);
3239 let RespondToChannelInvite {
3240 membership_update,
3241 notifications,
3242 } = db
3243 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3244 .await?;
3245
3246 let mut connection_pool = session.connection_pool().await;
3247 if let Some(membership_update) = membership_update {
3248 notify_membership_updated(
3249 &mut connection_pool,
3250 membership_update,
3251 session.user_id(),
3252 &session.peer,
3253 );
3254 } else {
3255 let update = proto::UpdateChannels {
3256 remove_channel_invitations: vec![channel_id.to_proto()],
3257 ..Default::default()
3258 };
3259
3260 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3261 session.peer.send(connection_id, update.clone())?;
3262 }
3263 };
3264
3265 send_notifications(&connection_pool, &session.peer, notifications);
3266
3267 response.send(proto::Ack {})?;
3268
3269 Ok(())
3270}
3271
3272/// Join the channels' room
3273async fn join_channel(
3274 request: proto::JoinChannel,
3275 response: Response<proto::JoinChannel>,
3276 session: Session,
3277) -> Result<()> {
3278 let channel_id = ChannelId::from_proto(request.channel_id);
3279 join_channel_internal(channel_id, Box::new(response), session).await
3280}
3281
/// Abstraction over the response handles that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests reply with a `proto::JoinRoomResponse`, so
/// `join_channel_internal` is written against this trait rather than a concrete `Response`.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call, presumably resolving to the inherent `Response::send`, so this
        // is not a recursive call to the trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same forwarding as above, for responses to `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3295
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database. This also yields any membership change and the
///    caller's role in the channel.
/// 3. Mint a LiveKit token when a LiveKit client is configured:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token.
///
///    Token errors are logged via `trace_err` and produce no connection info rather than a
///    failed join.
/// 4. Reply to the caller, apply and broadcast any membership change, and send the updated
///    room to its participants.
/// 5. After the inner block (and its database guard) ends, broadcast the channel's
///    participant list to subscribers and refresh the user's contacts.
///
/// `response` is generic over `JoinChannelInternalResponse` so that both `JoinChannel` and
/// `JoinRoom` requests can share this path.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that needs the database guard happens inside this block. The guard is dropped
    // before `update_user_contacts` below, which acquires `session.db()` itself.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room (presumably
            // `leave_room_for_session` needs to acquire it itself), then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        // `role` is the caller's role in the channel; it decides which LiveKit token is minted.
        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails (the failure is
        // logged by `trace_err`). The join itself still succeeds in either case.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; everyone else gets a
                    // room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Subscribe the caller to any newly granted channels and push the change to them.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to always return its channel; treat a missing one as an
    // error (reported after the caller has already received the join response).
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3389
3390/// Start editing the channel notes
3391async fn join_channel_buffer(
3392 request: proto::JoinChannelBuffer,
3393 response: Response<proto::JoinChannelBuffer>,
3394 session: Session,
3395) -> Result<()> {
3396 let db = session.db().await;
3397 let channel_id = ChannelId::from_proto(request.channel_id);
3398
3399 let open_response = db
3400 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3401 .await?;
3402
3403 let collaborators = open_response.collaborators.clone();
3404 response.send(open_response)?;
3405
3406 let update = UpdateChannelBufferCollaborators {
3407 channel_id: channel_id.to_proto(),
3408 collaborators: collaborators.clone(),
3409 };
3410 channel_buffer_updated(
3411 session.connection_id,
3412 collaborators
3413 .iter()
3414 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3415 &update,
3416 &session.peer,
3417 );
3418
3419 Ok(())
3420}
3421
3422/// Edit the channel notes
3423async fn update_channel_buffer(
3424 request: proto::UpdateChannelBuffer,
3425 session: Session,
3426) -> Result<()> {
3427 let db = session.db().await;
3428 let channel_id = ChannelId::from_proto(request.channel_id);
3429
3430 let (collaborators, epoch, version) = db
3431 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3432 .await?;
3433
3434 channel_buffer_updated(
3435 session.connection_id,
3436 collaborators.clone(),
3437 &proto::UpdateChannelBuffer {
3438 channel_id: channel_id.to_proto(),
3439 operations: request.operations,
3440 },
3441 &session.peer,
3442 );
3443
3444 let pool = &*session.connection_pool().await;
3445
3446 let non_collaborators =
3447 pool.channel_connection_ids(channel_id)
3448 .filter_map(|(connection_id, _)| {
3449 if collaborators.contains(&connection_id) {
3450 None
3451 } else {
3452 Some(connection_id)
3453 }
3454 });
3455
3456 broadcast(None, non_collaborators, |peer_id| {
3457 session.peer.send(
3458 peer_id,
3459 proto::UpdateChannels {
3460 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3461 channel_id: channel_id.to_proto(),
3462 epoch: epoch as u64,
3463 version: version.clone(),
3464 }],
3465 ..Default::default()
3466 },
3467 )
3468 });
3469
3470 Ok(())
3471}
3472
3473/// Rejoin the channel notes after a connection blip
3474async fn rejoin_channel_buffers(
3475 request: proto::RejoinChannelBuffers,
3476 response: Response<proto::RejoinChannelBuffers>,
3477 session: Session,
3478) -> Result<()> {
3479 let db = session.db().await;
3480 let buffers = db
3481 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3482 .await?;
3483
3484 for rejoined_buffer in &buffers {
3485 let collaborators_to_notify = rejoined_buffer
3486 .buffer
3487 .collaborators
3488 .iter()
3489 .filter_map(|c| Some(c.peer_id?.into()));
3490 channel_buffer_updated(
3491 session.connection_id,
3492 collaborators_to_notify,
3493 &proto::UpdateChannelBufferCollaborators {
3494 channel_id: rejoined_buffer.buffer.channel_id,
3495 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3496 },
3497 &session.peer,
3498 );
3499 }
3500
3501 response.send(proto::RejoinChannelBuffersResponse {
3502 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3503 })?;
3504
3505 Ok(())
3506}
3507
3508/// Stop editing the channel notes
3509async fn leave_channel_buffer(
3510 request: proto::LeaveChannelBuffer,
3511 response: Response<proto::LeaveChannelBuffer>,
3512 session: Session,
3513) -> Result<()> {
3514 let db = session.db().await;
3515 let channel_id = ChannelId::from_proto(request.channel_id);
3516
3517 let left_buffer = db
3518 .leave_channel_buffer(channel_id, session.connection_id)
3519 .await?;
3520
3521 response.send(Ack {})?;
3522
3523 channel_buffer_updated(
3524 session.connection_id,
3525 left_buffer.connections,
3526 &proto::UpdateChannelBufferCollaborators {
3527 channel_id: channel_id.to_proto(),
3528 collaborators: left_buffer.collaborators,
3529 },
3530 &session.peer,
3531 );
3532
3533 Ok(())
3534}
3535
3536fn channel_buffer_updated<T: EnvelopedMessage>(
3537 sender_id: ConnectionId,
3538 collaborators: impl IntoIterator<Item = ConnectionId>,
3539 message: &T,
3540 peer: &Peer,
3541) {
3542 broadcast(Some(sender_id), collaborators, |peer_id| {
3543 peer.send(peer_id, message.clone())
3544 });
3545}
3546
3547fn send_notifications(
3548 connection_pool: &ConnectionPool,
3549 peer: &Peer,
3550 notifications: db::NotificationBatch,
3551) {
3552 for (user_id, notification) in notifications {
3553 for connection_id in connection_pool.user_connection_ids(user_id) {
3554 if let Err(error) = peer.send(
3555 connection_id,
3556 proto::AddNotification {
3557 notification: Some(notification.clone()),
3558 },
3559 ) {
3560 tracing::error!(
3561 "failed to send notification to {:?} {}",
3562 connection_id,
3563 error
3564 );
3565 }
3566 }
3567 }
3568}
3569
3570/// Send a message to the channel
3571async fn send_channel_message(
3572 request: proto::SendChannelMessage,
3573 response: Response<proto::SendChannelMessage>,
3574 session: Session,
3575) -> Result<()> {
3576 // Validate the message body.
3577 let body = request.body.trim().to_string();
3578 if body.len() > MAX_MESSAGE_LEN {
3579 return Err(anyhow!("message is too long"))?;
3580 }
3581 if body.is_empty() {
3582 return Err(anyhow!("message can't be blank"))?;
3583 }
3584
3585 // TODO: adjust mentions if body is trimmed
3586
3587 let timestamp = OffsetDateTime::now_utc();
3588 let nonce = request.nonce.context("nonce can't be blank")?;
3589
3590 let channel_id = ChannelId::from_proto(request.channel_id);
3591 let CreatedChannelMessage {
3592 message_id,
3593 participant_connection_ids,
3594 notifications,
3595 } = session
3596 .db()
3597 .await
3598 .create_channel_message(
3599 channel_id,
3600 session.user_id(),
3601 &body,
3602 &request.mentions,
3603 timestamp,
3604 nonce.clone().into(),
3605 request.reply_to_message_id.map(MessageId::from_proto),
3606 )
3607 .await?;
3608
3609 let message = proto::ChannelMessage {
3610 sender_id: session.user_id().to_proto(),
3611 id: message_id.to_proto(),
3612 body,
3613 mentions: request.mentions,
3614 timestamp: timestamp.unix_timestamp() as u64,
3615 nonce: Some(nonce),
3616 reply_to_message_id: request.reply_to_message_id,
3617 edited_at: None,
3618 };
3619 broadcast(
3620 Some(session.connection_id),
3621 participant_connection_ids.clone(),
3622 |connection| {
3623 session.peer.send(
3624 connection,
3625 proto::ChannelMessageSent {
3626 channel_id: channel_id.to_proto(),
3627 message: Some(message.clone()),
3628 },
3629 )
3630 },
3631 );
3632 response.send(proto::SendChannelMessageResponse {
3633 message: Some(message),
3634 })?;
3635
3636 let pool = &*session.connection_pool().await;
3637 let non_participants =
3638 pool.channel_connection_ids(channel_id)
3639 .filter_map(|(connection_id, _)| {
3640 if participant_connection_ids.contains(&connection_id) {
3641 None
3642 } else {
3643 Some(connection_id)
3644 }
3645 });
3646 broadcast(None, non_participants, |peer_id| {
3647 session.peer.send(
3648 peer_id,
3649 proto::UpdateChannels {
3650 latest_channel_message_ids: vec![proto::ChannelMessageId {
3651 channel_id: channel_id.to_proto(),
3652 message_id: message_id.to_proto(),
3653 }],
3654 ..Default::default()
3655 },
3656 )
3657 });
3658 send_notifications(pool, &session.peer, notifications);
3659
3660 Ok(())
3661}
3662
3663/// Delete a channel message
3664async fn remove_channel_message(
3665 request: proto::RemoveChannelMessage,
3666 response: Response<proto::RemoveChannelMessage>,
3667 session: Session,
3668) -> Result<()> {
3669 let channel_id = ChannelId::from_proto(request.channel_id);
3670 let message_id = MessageId::from_proto(request.message_id);
3671 let (connection_ids, existing_notification_ids) = session
3672 .db()
3673 .await
3674 .remove_channel_message(channel_id, message_id, session.user_id())
3675 .await?;
3676
3677 broadcast(
3678 Some(session.connection_id),
3679 connection_ids,
3680 move |connection| {
3681 session.peer.send(connection, request.clone())?;
3682
3683 for notification_id in &existing_notification_ids {
3684 session.peer.send(
3685 connection,
3686 proto::DeleteNotification {
3687 notification_id: (*notification_id).to_proto(),
3688 },
3689 )?;
3690 }
3691
3692 Ok(())
3693 },
3694 );
3695 response.send(proto::Ack {})?;
3696 Ok(())
3697}
3698
3699async fn update_channel_message(
3700 request: proto::UpdateChannelMessage,
3701 response: Response<proto::UpdateChannelMessage>,
3702 session: Session,
3703) -> Result<()> {
3704 let channel_id = ChannelId::from_proto(request.channel_id);
3705 let message_id = MessageId::from_proto(request.message_id);
3706 let updated_at = OffsetDateTime::now_utc();
3707 let UpdatedChannelMessage {
3708 message_id,
3709 participant_connection_ids,
3710 notifications,
3711 reply_to_message_id,
3712 timestamp,
3713 deleted_mention_notification_ids,
3714 updated_mention_notifications,
3715 } = session
3716 .db()
3717 .await
3718 .update_channel_message(
3719 channel_id,
3720 message_id,
3721 session.user_id(),
3722 request.body.as_str(),
3723 &request.mentions,
3724 updated_at,
3725 )
3726 .await?;
3727
3728 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3729
3730 let message = proto::ChannelMessage {
3731 sender_id: session.user_id().to_proto(),
3732 id: message_id.to_proto(),
3733 body: request.body.clone(),
3734 mentions: request.mentions.clone(),
3735 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3736 nonce: Some(nonce),
3737 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3738 edited_at: Some(updated_at.unix_timestamp() as u64),
3739 };
3740
3741 response.send(proto::Ack {})?;
3742
3743 let pool = &*session.connection_pool().await;
3744 broadcast(
3745 Some(session.connection_id),
3746 participant_connection_ids,
3747 |connection| {
3748 session.peer.send(
3749 connection,
3750 proto::ChannelMessageUpdate {
3751 channel_id: channel_id.to_proto(),
3752 message: Some(message.clone()),
3753 },
3754 )?;
3755
3756 for notification_id in &deleted_mention_notification_ids {
3757 session.peer.send(
3758 connection,
3759 proto::DeleteNotification {
3760 notification_id: (*notification_id).to_proto(),
3761 },
3762 )?;
3763 }
3764
3765 for notification in &updated_mention_notifications {
3766 session.peer.send(
3767 connection,
3768 proto::UpdateNotification {
3769 notification: Some(notification.clone()),
3770 },
3771 )?;
3772 }
3773
3774 Ok(())
3775 },
3776 );
3777
3778 send_notifications(pool, &session.peer, notifications);
3779
3780 Ok(())
3781}
3782
3783/// Mark a channel message as read
3784async fn acknowledge_channel_message(
3785 request: proto::AckChannelMessage,
3786 session: Session,
3787) -> Result<()> {
3788 let channel_id = ChannelId::from_proto(request.channel_id);
3789 let message_id = MessageId::from_proto(request.message_id);
3790 let notifications = session
3791 .db()
3792 .await
3793 .observe_channel_message(channel_id, session.user_id(), message_id)
3794 .await?;
3795 send_notifications(
3796 &*session.connection_pool().await,
3797 &session.peer,
3798 notifications,
3799 );
3800 Ok(())
3801}
3802
3803/// Mark a buffer version as synced
3804async fn acknowledge_buffer_version(
3805 request: proto::AckBufferOperation,
3806 session: Session,
3807) -> Result<()> {
3808 let buffer_id = BufferId::from_proto(request.buffer_id);
3809 session
3810 .db()
3811 .await
3812 .observe_buffer_version(
3813 buffer_id,
3814 session.user_id(),
3815 request.epoch as i32,
3816 &request.version,
3817 )
3818 .await?;
3819 Ok(())
3820}
3821
3822/// Get a Supermaven API key for the user
3823async fn get_supermaven_api_key(
3824 _request: proto::GetSupermavenApiKey,
3825 response: Response<proto::GetSupermavenApiKey>,
3826 session: Session,
3827) -> Result<()> {
3828 let user_id: String = session.user_id().to_string();
3829 if !session.is_staff() {
3830 return Err(anyhow!("supermaven not enabled for this account"))?;
3831 }
3832
3833 let email = session.email().context("user must have an email")?;
3834
3835 let supermaven_admin_api = session
3836 .supermaven_client
3837 .as_ref()
3838 .context("supermaven not configured")?;
3839
3840 let result = supermaven_admin_api
3841 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3842 .await?;
3843
3844 response.send(proto::GetSupermavenApiKeyResponse {
3845 api_key: result.api_key,
3846 })?;
3847
3848 Ok(())
3849}
3850
3851/// Start receiving chat updates for a channel
3852async fn join_channel_chat(
3853 request: proto::JoinChannelChat,
3854 response: Response<proto::JoinChannelChat>,
3855 session: Session,
3856) -> Result<()> {
3857 let channel_id = ChannelId::from_proto(request.channel_id);
3858
3859 let db = session.db().await;
3860 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3861 .await?;
3862 let messages = db
3863 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3864 .await?;
3865 response.send(proto::JoinChannelChatResponse {
3866 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3867 messages,
3868 })?;
3869 Ok(())
3870}
3871
3872/// Stop receiving chat updates for a channel
3873async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3874 let channel_id = ChannelId::from_proto(request.channel_id);
3875 session
3876 .db()
3877 .await
3878 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3879 .await?;
3880 Ok(())
3881}
3882
3883/// Retrieve the chat history for a channel
3884async fn get_channel_messages(
3885 request: proto::GetChannelMessages,
3886 response: Response<proto::GetChannelMessages>,
3887 session: Session,
3888) -> Result<()> {
3889 let channel_id = ChannelId::from_proto(request.channel_id);
3890 let messages = session
3891 .db()
3892 .await
3893 .get_channel_messages(
3894 channel_id,
3895 session.user_id(),
3896 MESSAGE_COUNT_PER_PAGE,
3897 Some(MessageId::from_proto(request.before_message_id)),
3898 )
3899 .await?;
3900 response.send(proto::GetChannelMessagesResponse {
3901 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3902 messages,
3903 })?;
3904 Ok(())
3905}
3906
3907/// Retrieve specific chat messages
3908async fn get_channel_messages_by_id(
3909 request: proto::GetChannelMessagesById,
3910 response: Response<proto::GetChannelMessagesById>,
3911 session: Session,
3912) -> Result<()> {
3913 let message_ids = request
3914 .message_ids
3915 .iter()
3916 .map(|id| MessageId::from_proto(*id))
3917 .collect::<Vec<_>>();
3918 let messages = session
3919 .db()
3920 .await
3921 .get_channel_messages_by_id(session.user_id(), &message_ids)
3922 .await?;
3923 response.send(proto::GetChannelMessagesResponse {
3924 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925 messages,
3926 })?;
3927 Ok(())
3928}
3929
3930/// Retrieve the current users notifications
3931async fn get_notifications(
3932 request: proto::GetNotifications,
3933 response: Response<proto::GetNotifications>,
3934 session: Session,
3935) -> Result<()> {
3936 let notifications = session
3937 .db()
3938 .await
3939 .get_notifications(
3940 session.user_id(),
3941 NOTIFICATION_COUNT_PER_PAGE,
3942 request.before_id.map(db::NotificationId::from_proto),
3943 )
3944 .await?;
3945 response.send(proto::GetNotificationsResponse {
3946 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3947 notifications,
3948 })?;
3949 Ok(())
3950}
3951
3952/// Mark notifications as read
3953async fn mark_notification_as_read(
3954 request: proto::MarkNotificationRead,
3955 response: Response<proto::MarkNotificationRead>,
3956 session: Session,
3957) -> Result<()> {
3958 let database = &session.db().await;
3959 let notifications = database
3960 .mark_notification_as_read_by_id(
3961 session.user_id(),
3962 NotificationId::from_proto(request.notification_id),
3963 )
3964 .await?;
3965 send_notifications(
3966 &*session.connection_pool().await,
3967 &session.peer,
3968 notifications,
3969 );
3970 response.send(proto::Ack {})?;
3971 Ok(())
3972}
3973
3974/// Get the current users information
3975async fn get_private_user_info(
3976 _request: proto::GetPrivateUserInfo,
3977 response: Response<proto::GetPrivateUserInfo>,
3978 session: Session,
3979) -> Result<()> {
3980 let db = session.db().await;
3981
3982 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3983 let user = db
3984 .get_user_by_id(session.user_id())
3985 .await?
3986 .context("user not found")?;
3987 let flags = db.get_user_flags(session.user_id()).await?;
3988
3989 response.send(proto::GetPrivateUserInfoResponse {
3990 metrics_id,
3991 staff: user.admin,
3992 flags,
3993 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3994 })?;
3995 Ok(())
3996}
3997
3998/// Accept the terms of service (tos) on behalf of the current user
3999async fn accept_terms_of_service(
4000 _request: proto::AcceptTermsOfService,
4001 response: Response<proto::AcceptTermsOfService>,
4002 session: Session,
4003) -> Result<()> {
4004 let db = session.db().await;
4005
4006 let accepted_tos_at = Utc::now();
4007 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4008 .await?;
4009
4010 response.send(proto::AcceptTermsOfServiceResponse {
4011 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4012 })?;
4013 Ok(())
4014}
4015
/// The minimum account age an account must have in order to use the LLM service.
///
/// Expressed as a `chrono::Duration` of 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4018
/// Create an LLM API token for the current user and send it back.
///
/// The user must exist and must have accepted the terms of service. The server must have both
/// a Stripe client and a Stripe billing object configured.
///
/// Before the token is built, the user is guaranteed:
/// - a billing customer: looked up in the database, or otherwise found or created in Stripe
///   by email address and then via `find_or_create_billing_customer`;
/// - an active billing subscription: otherwise a Zed Free subscription is started in Stripe
///   and recorded in the database.
///
/// The token is then produced by `LlmTokenClaims::create` from:
/// - the user and their staff status;
/// - the billing customer and billing preferences;
/// - the feature flags;
/// - the subscription;
/// - the session's system id and the app config.
///
/// NOTE(review): nothing visible here serializes concurrent calls for a user without an
/// active subscription, so two racing requests might each create a Zed Free subscription —
/// confirm whether this is guarded elsewhere.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // No LLM token without an accepted terms of service.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Prefer the billing customer already stored in the database; otherwise resolve the Stripe
    // customer by email and find or create the matching billing customer record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Without an active subscription, subscribe the customer to Zed Free in Stripe and record
    // the new subscription in the database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4102
4103fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4104 let message = match message {
4105 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4106 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4107 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4108 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4109 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4110 code: frame.code.into(),
4111 reason: frame.reason.as_str().to_owned().into(),
4112 })),
4113 // We should never receive a frame while reading the message, according
4114 // to the `tungstenite` maintainers:
4115 //
4116 // > It cannot occur when you read messages from the WebSocket, but it
4117 // > can be used when you want to send the raw frames (e.g. you want to
4118 // > send the frames to the WebSocket without composing the full message first).
4119 // >
4120 // > — https://github.com/snapview/tungstenite-rs/issues/268
4121 TungsteniteMessage::Frame(_) => {
4122 bail!("received an unexpected frame while reading the message")
4123 }
4124 };
4125
4126 Ok(message)
4127}
4128
4129fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4130 match message {
4131 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4132 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4133 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4134 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4135 AxumMessage::Close(frame) => {
4136 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4137 code: frame.code.into(),
4138 reason: frame.reason.as_ref().into(),
4139 }))
4140 }
4141 }
4142}
4143
4144fn notify_membership_updated(
4145 connection_pool: &mut ConnectionPool,
4146 result: MembershipUpdated,
4147 user_id: UserId,
4148 peer: &Peer,
4149) {
4150 for membership in &result.new_channels.channel_memberships {
4151 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4152 }
4153 for channel_id in &result.removed_channels {
4154 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4155 }
4156
4157 let user_channels_update = proto::UpdateUserChannels {
4158 channel_memberships: result
4159 .new_channels
4160 .channel_memberships
4161 .iter()
4162 .map(|cm| proto::ChannelMembership {
4163 channel_id: cm.channel_id.to_proto(),
4164 role: cm.role.into(),
4165 })
4166 .collect(),
4167 ..Default::default()
4168 };
4169
4170 let mut update = build_channels_update(result.new_channels);
4171 update.delete_channels = result
4172 .removed_channels
4173 .into_iter()
4174 .map(|id| id.to_proto())
4175 .collect();
4176 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4177
4178 for connection_id in connection_pool.user_connection_ids(user_id) {
4179 peer.send(connection_id, user_channels_update.clone())
4180 .trace_err();
4181 peer.send(connection_id, update.clone()).trace_err();
4182 }
4183}
4184
4185fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4186 proto::UpdateUserChannels {
4187 channel_memberships: channels
4188 .channel_memberships
4189 .iter()
4190 .map(|m| proto::ChannelMembership {
4191 channel_id: m.channel_id.to_proto(),
4192 role: m.role.into(),
4193 })
4194 .collect(),
4195 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4196 observed_channel_message_id: channels.observed_channel_messages.clone(),
4197 }
4198}
4199
4200fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4201 let mut update = proto::UpdateChannels::default();
4202
4203 for channel in channels.channels {
4204 update.channels.push(channel.to_proto());
4205 }
4206
4207 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4208 update.latest_channel_message_ids = channels.latest_channel_messages;
4209
4210 for (channel_id, participants) in channels.channel_participants {
4211 update
4212 .channel_participants
4213 .push(proto::ChannelParticipants {
4214 channel_id: channel_id.to_proto(),
4215 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4216 });
4217 }
4218
4219 for channel in channels.invited_channels {
4220 update.channel_invitations.push(channel.to_proto());
4221 }
4222
4223 update
4224}
4225
4226fn build_initial_contacts_update(
4227 contacts: Vec<db::Contact>,
4228 pool: &ConnectionPool,
4229) -> proto::UpdateContacts {
4230 let mut update = proto::UpdateContacts::default();
4231
4232 for contact in contacts {
4233 match contact {
4234 db::Contact::Accepted { user_id, busy } => {
4235 update.contacts.push(contact_for_user(user_id, busy, pool));
4236 }
4237 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4238 db::Contact::Incoming { user_id } => {
4239 update
4240 .incoming_requests
4241 .push(proto::IncomingContactRequest {
4242 requester_id: user_id.to_proto(),
4243 })
4244 }
4245 }
4246 }
4247
4248 update
4249}
4250
4251fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4252 proto::Contact {
4253 user_id: user_id.to_proto(),
4254 online: pool.is_user_online(user_id),
4255 busy,
4256 }
4257}
4258
4259fn room_updated(room: &proto::Room, peer: &Peer) {
4260 broadcast(
4261 None,
4262 room.participants
4263 .iter()
4264 .filter_map(|participant| Some(participant.peer_id?.into())),
4265 |peer_id| {
4266 peer.send(
4267 peer_id,
4268 proto::RoomUpdated {
4269 room: Some(room.clone()),
4270 },
4271 )
4272 },
4273 );
4274}
4275
4276fn channel_updated(
4277 channel: &db::channel::Model,
4278 room: &proto::Room,
4279 peer: &Peer,
4280 pool: &ConnectionPool,
4281) {
4282 let participants = room
4283 .participants
4284 .iter()
4285 .map(|p| p.user_id)
4286 .collect::<Vec<_>>();
4287
4288 broadcast(
4289 None,
4290 pool.channel_connection_ids(channel.root_id())
4291 .filter_map(|(channel_id, role)| {
4292 role.can_see_channel(channel.visibility)
4293 .then_some(channel_id)
4294 }),
4295 |peer_id| {
4296 peer.send(
4297 peer_id,
4298 proto::UpdateChannels {
4299 channel_participants: vec![proto::ChannelParticipants {
4300 channel_id: channel.id.to_proto(),
4301 participant_user_ids: participants.clone(),
4302 }],
4303 ..Default::default()
4304 },
4305 )
4306 },
4307 );
4308}
4309
4310async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4311 let db = session.db().await;
4312
4313 let contacts = db.get_contacts(user_id).await?;
4314 let busy = db.is_user_busy(user_id).await?;
4315
4316 let pool = session.connection_pool().await;
4317 let updated_contact = contact_for_user(user_id, busy, &pool);
4318 for contact in contacts {
4319 if let db::Contact::Accepted {
4320 user_id: contact_user_id,
4321 ..
4322 } = contact
4323 {
4324 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4325 session
4326 .peer
4327 .send(
4328 contact_conn_id,
4329 proto::UpdateContacts {
4330 contacts: vec![updated_contact.clone()],
4331 remove_contacts: Default::default(),
4332 incoming_requests: Default::default(),
4333 remove_incoming_requests: Default::default(),
4334 outgoing_requests: Default::default(),
4335 remove_outgoing_requests: Default::default(),
4336 },
4337 )
4338 .trace_err();
4339 }
4340 }
4341 }
4342 Ok(())
4343}
4344
/// Removes `connection_id` from the room it is in and notifies everyone
/// affected by the departure.
///
/// Returns `Ok(())` without doing anything else when `leave_room` reports
/// that the connection was not in a room. Otherwise it:
/// - notifies the other connections of each project left (see `project_left`),
/// - broadcasts the updated room to its remaining participants,
/// - broadcasts the new participant list to channel members when the left
///   room has an associated channel,
/// - sends `CallCanceled` to every connection of the users listed in
///   `canceled_calls_to_user_ids`,
/// - refreshes the contact entry of the leaving user and of each of those
///   users for their accepted contacts,
/// - removes the user from the LiveKit room, and deletes that LiveKit room
///   when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. Message
/// send failures and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so they remain
    // usable after that statement ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` before the room itself is taken, so the
        // `room` broadcast by `room_updated` below carries a default
        // `livekit_room` value.
        // NOTE(review): confirm no client relies on `livekit_room` in `RoomUpdated`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so `pool` is dropped before `update_user_contacts`, which
        // acquires the connection pool itself.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // Best effort: LiveKit failures are logged, not propagated.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4418
4419async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4420 let left_channel_buffers = session
4421 .db()
4422 .await
4423 .leave_channel_buffers(session.connection_id)
4424 .await?;
4425
4426 for left_buffer in left_channel_buffers {
4427 channel_buffer_updated(
4428 session.connection_id,
4429 left_buffer.connections,
4430 &proto::UpdateChannelBufferCollaborators {
4431 channel_id: left_buffer.channel_id.to_proto(),
4432 collaborators: left_buffer.collaborators,
4433 },
4434 &session.peer,
4435 );
4436 }
4437
4438 Ok(())
4439}
4440
4441fn project_left(project: &db::LeftProject, session: &Session) {
4442 for connection_id in &project.connection_ids {
4443 if project.should_unshare {
4444 session
4445 .peer
4446 .send(
4447 *connection_id,
4448 proto::UnshareProject {
4449 project_id: project.id.to_proto(),
4450 },
4451 )
4452 .trace_err();
4453 } else {
4454 session
4455 .peer
4456 .send(
4457 *connection_id,
4458 proto::RemoveProjectCollaborator {
4459 project_id: project.id.to_proto(),
4460 peer_id: Some(session.connection_id.into()),
4461 },
4462 )
4463 .trace_err();
4464 }
4465 }
4466}
4467
/// Extension trait for turning a fallible value into an `Option` while
/// reporting the failure instead of propagating it.
pub trait ResultExt {
    /// Converts `self` into an `Option` holding the success value.
    ///
    /// The implementation for `Result` logs the error before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The success type yielded by `trace_err`.
    type Ok;
}
4473
4474impl<T, E> ResultExt for Result<T, E>
4475where
4476 E: std::fmt::Debug,
4477{
4478 type Ok = T;
4479
4480 #[track_caller]
4481 fn trace_err(self) -> Option<T> {
4482 match self {
4483 Ok(value) => Some(value),
4484 Err(error) => {
4485 tracing::error!("{:?}", error);
4486 None
4487 }
4488 }
4489 }
4490}