1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::{
9 AppState, Error, Result, auth,
10 db::{
11 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
12 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
13 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
14 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
15 },
16 executor::Executor,
17};
18use anyhow::{Context as _, anyhow, bail};
19use async_tungstenite::tungstenite::{
20 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
21};
22use axum::{
23 Extension, Router, TypedHeader,
24 body::Body,
25 extract::{
26 ConnectInfo, WebSocketUpgrade,
27 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
28 },
29 headers::{Header, HeaderName},
30 http::StatusCode,
31 middleware,
32 response::IntoResponse,
33 routing::get,
34};
35use chrono::Utc;
36use collections::{HashMap, HashSet};
37pub use connection_pool::{ConnectionPool, ZedVersion};
38use core::fmt::{self, Debug, Formatter};
39use reqwest_client::ReqwestClient;
40use rpc::proto::split_repository_update;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
45 stream::FuturesUnordered,
46};
47use prometheus::{IntGauge, register_int_gauge};
48use rpc::{
49 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 Arc, OnceLock,
67 atomic::{AtomicBool, Ordering::SeqCst},
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{Semaphore, watch};
73use tower::ServiceBuilder;
74use tracing::{
75 Instrument,
76 field::{self},
77 info_span, instrument,
78};
79
/// Reconnection timeout: thirty seconds.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before `Server::start` begins cleaning up resources reported as stale.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are gone,
/// their old resources can be cleaned up, so this waits a little longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size for channel-message queries (presumably; used outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel-message length (presumably in bytes; used outside this view).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size for notification queries (presumably; used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106#[derive(Clone, Debug)]
107pub enum Principal {
108 User(User),
109 Impersonated { user: User, admin: User },
110}
111
112impl Principal {
113 fn user(&self) -> &User {
114 match self {
115 Principal::User(user) => user,
116 Principal::Impersonated { user, .. } => user,
117 }
118 }
119
120 fn update_span(&self, span: &tracing::Span) {
121 match &self {
122 Principal::User(user) => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 }
126 Principal::Impersonated { user, admin } => {
127 span.record("user_id", user.id.0);
128 span.record("login", &user.github_login);
129 span.record("impersonator", &admin.github_login);
130 }
131 }
132 }
133}
134
/// Per-connection state handed (as a clone) to every message handler invoked for
/// that connection.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind a per-session async mutex; `handle_connection` creates
    /// a fresh mutex for each connection. Access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the `Server`; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` because nothing visible reads this field; confirm
    // whether it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn is_staff(&self) -> bool {
172 match &self.principal {
173 Principal::User(user) => user.admin,
174 Principal::Impersonated { .. } => true,
175 }
176 }
177
178 fn user_id(&self) -> UserId {
179 match &self.principal {
180 Principal::User(user) => user.id,
181 Principal::Impersonated { user, .. } => user.id,
182 }
183 }
184
185 pub fn email(&self) -> Option<String> {
186 match &self.principal {
187 Principal::User(user) => user.email_address.clone(),
188 Principal::Impersonated { user, .. } => user.email_address.clone(),
189 }
190 }
191}
192
193impl Debug for Session {
194 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
195 let mut result = f.debug_struct("Session");
196 match &self.principal {
197 Principal::User(user) => {
198 result.field("user", &user.github_login);
199 }
200 Principal::Impersonated { user, admin } => {
201 result.field("user", &user.github_login);
202 result.field("impersonator", &admin.github_login);
203 }
204 }
205 result.field("connection_id", &self.connection_id).finish()
206 }
207}
208
209struct DbHandle(Arc<Database>);
210
211impl Deref for DbHandle {
212 type Target = Database;
213
214 fn deref(&self) -> &Self::Target {
215 self.0.as_ref()
216 }
217}
218
/// The collaboration RPC server.
///
/// It owns the peer, the pool of live connections, and the table of message handlers.
pub struct Server {
    /// This server's id; `reset` replaces it in tests.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; the same `Arc` is cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Dispatchers keyed by the `TypeId` of the message type they handle
    /// (populated by `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. `teardown` sends `true` on it, and connection tasks subscribed
    /// in `handle_connection` return when it changes.
    teardown: watch::Sender<bool>,
}
227
/// Lock guard over the shared `ConnectionPool` that is neither `Send` nor `Sync`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`/`!Sync`, so this marker makes the compiler reject holding
    // the guard across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
232
/// Serializable view of the server's peer and connection pool.
///
/// The pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool behind the guard, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
239
240pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
241where
242 S: Serializer,
243 T: Deref<Target = U>,
244 U: Serialize,
245{
246 Serialize::serialize(value.deref(), serializer)
247}
248
249impl Server {
250 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
251 let mut server = Self {
252 id: parking_lot::Mutex::new(id),
253 peer: Peer::new(id.0 as u32),
254 app_state: app_state.clone(),
255 connection_pool: Default::default(),
256 handlers: Default::default(),
257 teardown: watch::channel(false).0,
258 };
259
260 server
261 .add_request_handler(ping)
262 .add_request_handler(create_room)
263 .add_request_handler(join_room)
264 .add_request_handler(rejoin_room)
265 .add_request_handler(leave_room)
266 .add_request_handler(set_room_participant_role)
267 .add_request_handler(call)
268 .add_request_handler(cancel_call)
269 .add_message_handler(decline_call)
270 .add_request_handler(update_participant_location)
271 .add_request_handler(share_project)
272 .add_message_handler(unshare_project)
273 .add_request_handler(join_project)
274 .add_message_handler(leave_project)
275 .add_request_handler(update_project)
276 .add_request_handler(update_worktree)
277 .add_request_handler(update_repository)
278 .add_request_handler(remove_repository)
279 .add_message_handler(start_language_server)
280 .add_message_handler(update_language_server)
281 .add_message_handler(update_diagnostic_summary)
282 .add_message_handler(update_worktree_settings)
283 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
284 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
287 .add_request_handler(forward_find_search_candidates_request)
288 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
293 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
294 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
295 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
296 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
297 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
298 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
299 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
303 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
304 .add_request_handler(
305 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
306 )
307 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
311 .add_request_handler(
312 forward_read_only_project_request::<proto::LanguageServerIdForName>,
313 )
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
332 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
333 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
338 .add_request_handler(
339 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
340 )
341 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
342 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
345 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
346 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
347 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
348 .add_message_handler(create_buffer_for_peer)
349 .add_request_handler(update_buffer)
350 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
356 .add_request_handler(get_users)
357 .add_request_handler(fuzzy_search_users)
358 .add_request_handler(request_contact)
359 .add_request_handler(remove_contact)
360 .add_request_handler(respond_to_contact_request)
361 .add_message_handler(subscribe_to_channels)
362 .add_request_handler(create_channel)
363 .add_request_handler(delete_channel)
364 .add_request_handler(invite_channel_member)
365 .add_request_handler(remove_channel_member)
366 .add_request_handler(set_channel_member_role)
367 .add_request_handler(set_channel_visibility)
368 .add_request_handler(rename_channel)
369 .add_request_handler(join_channel_buffer)
370 .add_request_handler(leave_channel_buffer)
371 .add_message_handler(update_channel_buffer)
372 .add_request_handler(rejoin_channel_buffers)
373 .add_request_handler(get_channel_members)
374 .add_request_handler(respond_to_channel_invite)
375 .add_request_handler(join_channel)
376 .add_request_handler(join_channel_chat)
377 .add_message_handler(leave_channel_chat)
378 .add_request_handler(send_channel_message)
379 .add_request_handler(remove_channel_message)
380 .add_request_handler(update_channel_message)
381 .add_request_handler(get_channel_messages)
382 .add_request_handler(get_channel_messages_by_id)
383 .add_request_handler(get_notifications)
384 .add_request_handler(mark_notification_as_read)
385 .add_request_handler(move_channel)
386 .add_request_handler(follow)
387 .add_message_handler(unfollow)
388 .add_message_handler(update_followers)
389 .add_request_handler(get_private_user_info)
390 .add_request_handler(get_llm_api_token)
391 .add_request_handler(accept_terms_of_service)
392 .add_message_handler(acknowledge_channel_message)
393 .add_message_handler(acknowledge_buffer_version)
394 .add_request_handler(get_supermaven_api_key)
395 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
396 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
397 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
398 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
399 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
400 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
401 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
402 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
403 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
404 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
405 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
406 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
407 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
408 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
409 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
410 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
411 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
412 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
414 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
415 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
416 .add_message_handler(update_context);
417
418 Arc::new(server)
419 }
420
/// Spawns a detached background task that reclaims resources the database reports as
/// stale for this server, then returns `Ok(())` immediately.
///
/// After waiting `CLEANUP_TIMEOUT`, the task:
/// 1. asks `stale_server_resource_ids` for stale room ids and channel-buffer ids;
/// 2. for each channel buffer, clears stale collaborators and sends the refreshed
///    collaborator list to the remaining connections;
/// 3. for each room, clears stale participants, broadcasts the room/channel update,
///    notifies users whose calls were canceled, pushes refreshed contact entries to
///    the affected users' accepted contacts, and deletes the LiveKit room if no
///    participants remain;
/// 4. calls `delete_stale_servers`.
///
/// Each fallible step goes through `.trace_err()` instead of `?`, so one failure does
/// not abort the rest of the cleanup.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        // Tell every connection still on this buffer about the new
                        // collaborator list.
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        // Both removed participants and users whose calls were canceled
                        // need their contact entry refreshed.
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        // Moved out since `refreshed_room` is not used after this block.
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        // Only delete the LiveKit room once nobody is left in it.
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        // Skip this user unless both lookups succeeded.
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                // Only accepted contacts receive the refreshed entry.
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
566
567 pub fn teardown(&self) {
568 self.peer.teardown();
569 self.connection_pool.lock().reset();
570 let _ = self.teardown.send(true);
571 }
572
573 #[cfg(test)]
574 pub fn reset(&self, id: ServerId) {
575 self.teardown();
576 *self.id.lock() = id;
577 self.peer.reset(id.0 as u32);
578 let _ = self.teardown.send(false);
579 }
580
/// The server's current id (test-only).
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
585
586 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
587 where
588 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
589 Fut: 'static + Send + Future<Output = Result<()>>,
590 M: EnvelopedMessage,
591 {
592 let prev_handler = self.handlers.insert(
593 TypeId::of::<M>(),
594 Box::new(move |envelope, session| {
595 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
596 let received_at = envelope.received_at;
597 tracing::info!("message received");
598 let start_time = Instant::now();
599 let future = (handler)(*envelope, session);
600 async move {
601 let result = future.await;
602 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
603 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
604 let queue_duration_ms = total_duration_ms - processing_duration_ms;
605 let payload_type = M::NAME;
606
607 match result {
608 Err(error) => {
609 tracing::error!(
610 ?error,
611 total_duration_ms,
612 processing_duration_ms,
613 queue_duration_ms,
614 payload_type,
615 "error handling message"
616 )
617 }
618 Ok(()) => tracing::info!(
619 total_duration_ms,
620 processing_duration_ms,
621 queue_duration_ms,
622 "finished handling message"
623 ),
624 }
625 }
626 .boxed()
627 }),
628 );
629 if prev_handler.is_some() {
630 panic!("registered a handler for the same message twice");
631 }
632 self
633 }
634
635 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
636 where
637 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
638 Fut: 'static + Send + Future<Output = Result<()>>,
639 M: EnvelopedMessage,
640 {
641 self.add_handler(move |envelope, session| handler(envelope.payload, session));
642 self
643 }
644
645 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
646 where
647 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
648 Fut: Send + Future<Output = Result<()>>,
649 M: RequestMessage,
650 {
651 let handler = Arc::new(handler);
652 self.add_handler(move |envelope, session| {
653 let receipt = envelope.receipt();
654 let handler = handler.clone();
655 async move {
656 let peer = session.peer.clone();
657 let responded = Arc::new(AtomicBool::default());
658 let response = Response {
659 peer: peer.clone(),
660 responded: responded.clone(),
661 receipt,
662 };
663 match (handler)(envelope.payload, response, session).await {
664 Ok(()) => {
665 if responded.load(std::sync::atomic::Ordering::SeqCst) {
666 Ok(())
667 } else {
668 Err(anyhow!("handler did not send a response"))?
669 }
670 }
671 Err(error) => {
672 let proto_err = match &error {
673 Error::Internal(err) => err.to_proto(),
674 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
675 };
676 peer.respond_with_error(receipt, proto_err)?;
677 Err(error)
678 }
679 }
680 }
681 })
682 }
683
/// Returns a future that drives one client connection from handshake to sign-out.
///
/// When polled, the future:
/// 1. returns immediately if the teardown flag already reads `true`;
/// 2. registers the connection with the peer and records its id on the span;
/// 3. builds the per-connection `Session`, including a Supermaven admin client when
///    `supermaven_admin_api_key` is configured;
/// 4. sends the initial client update (hello, contacts, channel subscriptions,
///    pending incoming call);
/// 5. dispatches incoming messages to the registered handlers until the I/O task
///    finishes, the incoming stream ends, or teardown is signalled;
/// 6. calls `connection_lost` to sign the user out, except on the teardown path and
///    the early-return error paths.
///
/// The teardown receiver is subscribed synchronously, when this method is called,
/// not when the future is first polled.
///
/// NOTE(review): when the HTTP client cannot be built, or the initial client update
/// fails, this returns after `add_connection` without calling `connection_lost`.
/// Confirm whether that leaves stale peer or pool state behind.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> + use<> {
    let this = self.clone();
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        // This client is used only to build the optional Supermaven admin client.
        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
            supermaven_admin_api_key.to_string(),
            http_client.clone(),
        )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // Caps in-flight handlers (foreground and background) at 256. A permit is taken
        // before the next message is read and released when its handler completes.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // Biased: teardown is checked first, then I/O completion, then progress on
            // running foreground handlers, and only then a new message.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            // The permit moves into the handler future so the slot stays
                            // occupied until handling finishes.
                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set cancels any foreground handlers still in flight.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
831
832 async fn send_initial_client_update(
833 &self,
834 connection_id: ConnectionId,
835 zed_version: ZedVersion,
836 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
837 session: &Session,
838 ) -> Result<()> {
839 self.peer.send(
840 connection_id,
841 proto::Hello {
842 peer_id: Some(connection_id.into()),
843 },
844 )?;
845 tracing::info!("sent hello message");
846 if let Some(send_connection_id) = send_connection_id.take() {
847 let _ = send_connection_id.send(connection_id);
848 }
849
850 match &session.principal {
851 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
852 if !user.connected_once {
853 self.peer.send(connection_id, proto::ShowContacts {})?;
854 self.app_state
855 .db
856 .set_user_connected_once(user.id, true)
857 .await?;
858 }
859
860 update_user_plan(session).await?;
861
862 let contacts = self.app_state.db.get_contacts(user.id).await?;
863
864 {
865 let mut pool = self.connection_pool.lock();
866 pool.add_connection(connection_id, user.id, user.admin, zed_version);
867 self.peer.send(
868 connection_id,
869 build_initial_contacts_update(contacts, &pool),
870 )?;
871 }
872
873 if should_auto_subscribe_to_channels(zed_version) {
874 subscribe_user_to_channels(user.id, session).await?;
875 }
876
877 if let Some(incoming_call) =
878 self.app_state.db.incoming_call_for_user(user.id).await?
879 {
880 self.peer.send(connection_id, incoming_call)?;
881 }
882
883 update_user_contacts(user.id, session).await?;
884 }
885 }
886
887 Ok(())
888 }
889
890 pub async fn invite_code_redeemed(
891 self: &Arc<Self>,
892 inviter_id: UserId,
893 invitee_id: UserId,
894 ) -> Result<()> {
895 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
896 if let Some(code) = &user.invite_code {
897 let pool = self.connection_pool.lock();
898 let invitee_contact = contact_for_user(invitee_id, false, &pool);
899 for connection_id in pool.user_connection_ids(inviter_id) {
900 self.peer.send(
901 connection_id,
902 proto::UpdateContacts {
903 contacts: vec![invitee_contact.clone()],
904 ..Default::default()
905 },
906 )?;
907 self.peer.send(
908 connection_id,
909 proto::UpdateInviteInfo {
910 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
911 count: user.invite_count as u32,
912 },
913 )?;
914 }
915 }
916 }
917 Ok(())
918 }
919
920 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
921 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
922 if let Some(invite_code) = &user.invite_code {
923 let pool = self.connection_pool.lock();
924 for connection_id in pool.user_connection_ids(user_id) {
925 self.peer.send(
926 connection_id,
927 proto::UpdateInviteInfo {
928 url: format!(
929 "{}{}",
930 self.app_state.config.invite_link_prefix, invite_code
931 ),
932 count: user.invite_count as u32,
933 },
934 )?;
935 }
936 }
937 }
938 Ok(())
939 }
940
941 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
942 let user = self
943 .app_state
944 .db
945 .get_user_by_id(user_id)
946 .await?
947 .context("user not found")?;
948
949 let update_user_plan = make_update_user_plan_message(
950 &user,
951 user.admin,
952 &self.app_state.db,
953 self.app_state.llm_db.clone(),
954 )
955 .await?;
956
957 let pool = self.connection_pool.lock();
958 for connection_id in pool.user_connection_ids(user_id) {
959 self.peer
960 .send(connection_id, update_user_plan.clone())
961 .trace_err();
962 }
963
964 Ok(())
965 }
966
967 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
968 let pool = self.connection_pool.lock();
969 for connection_id in pool.user_connection_ids(user_id) {
970 self.peer
971 .send(connection_id, proto::RefreshLlmToken {})
972 .trace_err();
973 }
974 }
975
976 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
977 ServerSnapshot {
978 connection_pool: ConnectionPoolGuard {
979 guard: self.connection_pool.lock(),
980 _not_send: PhantomData,
981 },
982 peer: &self.peer,
983 }
984 }
985}
986
987impl Deref for ConnectionPoolGuard<'_> {
988 type Target = ConnectionPool;
989
990 fn deref(&self) -> &Self::Target {
991 &self.guard
992 }
993}
994
995impl DerefMut for ConnectionPoolGuard<'_> {
996 fn deref_mut(&mut self) -> &mut Self::Target {
997 &mut self.guard
998 }
999}
1000
/// When a guard is released, runs the pool's `check_invariants` (reached through `Deref`).
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call is compiled only under `cfg(test)`; in other builds dropping the guard
        // just releases the lock.
        #[cfg(test)]
        self.check_invariants();
    }
}
1007
1008fn broadcast<F>(
1009 sender_id: Option<ConnectionId>,
1010 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1011 mut f: F,
1012) where
1013 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1014{
1015 for receiver_id in receiver_ids {
1016 if Some(receiver_id) != sender_id {
1017 if let Err(error) = f(receiver_id) {
1018 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1019 }
1020 }
1021 }
1022}
1023
1024pub struct ProtocolVersion(u32);
1025
1026impl Header for ProtocolVersion {
1027 fn name() -> &'static HeaderName {
1028 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1029 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1030 }
1031
1032 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1033 where
1034 Self: Sized,
1035 I: Iterator<Item = &'i axum::http::HeaderValue>,
1036 {
1037 let version = values
1038 .next()
1039 .ok_or_else(axum::headers::Error::invalid)?
1040 .to_str()
1041 .map_err(|_| axum::headers::Error::invalid())?
1042 .parse()
1043 .map_err(|_| axum::headers::Error::invalid())?;
1044 Ok(Self(version))
1045 }
1046
1047 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1048 values.extend([self.0.to_string().parse().unwrap()]);
1049 }
1050}
1051
1052pub struct AppVersionHeader(SemanticVersion);
1053impl Header for AppVersionHeader {
1054 fn name() -> &'static HeaderName {
1055 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1057 }
1058
1059 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060 where
1061 Self: Sized,
1062 I: Iterator<Item = &'i axum::http::HeaderValue>,
1063 {
1064 let version = values
1065 .next()
1066 .ok_or_else(axum::headers::Error::invalid)?
1067 .to_str()
1068 .map_err(|_| axum::headers::Error::invalid())?
1069 .parse()
1070 .map_err(|_| axum::headers::Error::invalid())?;
1071 Ok(Self(version))
1072 }
1073
1074 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075 values.extend([self.0.to_string().parse().unwrap()]);
1076 }
1077}
1078
1079pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1080 Router::new()
1081 .route("/rpc", get(handle_websocket_request))
1082 .layer(
1083 ServiceBuilder::new()
1084 .layer(Extension(server.app_state.clone()))
1085 .layer(middleware::from_fn(auth::validate_header)),
1086 )
1087 .route("/metrics", get(handle_metrics))
1088 .layer(Extension(server))
1089}
1090
1091pub async fn handle_websocket_request(
1092 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1093 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1094 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1095 Extension(server): Extension<Arc<Server>>,
1096 Extension(principal): Extension<Principal>,
1097 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1098 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1099 ws: WebSocketUpgrade,
1100) -> axum::response::Response {
1101 if protocol_version != rpc::PROTOCOL_VERSION {
1102 return (
1103 StatusCode::UPGRADE_REQUIRED,
1104 "client must be upgraded".to_string(),
1105 )
1106 .into_response();
1107 }
1108
1109 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1110 return (
1111 StatusCode::UPGRADE_REQUIRED,
1112 "no version header found".to_string(),
1113 )
1114 .into_response();
1115 };
1116
1117 if !version.can_collaborate() {
1118 return (
1119 StatusCode::UPGRADE_REQUIRED,
1120 "client must be upgraded".to_string(),
1121 )
1122 .into_response();
1123 }
1124
1125 let socket_address = socket_address.to_string();
1126 ws.on_upgrade(move |socket| {
1127 let socket = socket
1128 .map_ok(to_tungstenite_message)
1129 .err_into()
1130 .with(|message| async move { to_axum_message(message) });
1131 let connection = Connection::new(Box::pin(socket));
1132 async move {
1133 server
1134 .handle_connection(
1135 connection,
1136 socket_address,
1137 principal,
1138 version,
1139 country_code_header.map(|header| header.to_string()),
1140 system_id_header.map(|header| header.to_string()),
1141 None,
1142 Executor::Production,
1143 )
1144 .await;
1145 }
1146 })
1147}
1148
1149pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1150 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151 let connections_metric = CONNECTIONS_METRIC
1152 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1153
1154 let connections = server
1155 .connection_pool
1156 .lock()
1157 .connections()
1158 .filter(|connection| !connection.admin)
1159 .count();
1160 connections_metric.set(connections as _);
1161
1162 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1163 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1164 register_int_gauge!(
1165 "shared_projects",
1166 "number of open projects with one or more guests"
1167 )
1168 .unwrap()
1169 });
1170
1171 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1172 shared_projects_metric.set(shared_projects as _);
1173
1174 let encoder = prometheus::TextEncoder::new();
1175 let metric_families = prometheus::gather();
1176 let encoded_metrics = encoder
1177 .encode_to_string(&metric_families)
1178 .map_err(|err| anyhow!("{err}"))?;
1179 Ok(encoded_metrics)
1180}
1181
/// Cleans up after a client's connection drops.
///
/// Right away, the connection is disconnected from the peer and removed from the
/// connection pool. The database is also told about the lost connection; errors there only
/// go through `trace_err`.
///
/// The function then waits for whichever of these happens first:
/// - `RECONNECT_TIMEOUT` elapses. The session leaves its room and channel buffers. If the
///   user has no other live connection, `decline_call(None, ..)` is run and any returned
///   room is broadcast. Finally the user's contacts are updated.
/// - `teardown` changes. The cleanup above is skipped.
///
/// NOTE(review): the delay presumably gives the client a window to reconnect before its
/// room/buffer state is released; confirm against the reconnect path.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout branch is
    // checked before the teardown branch on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user remains online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1228
1229/// Acknowledges a ping from a client, used to keep the connection alive.
1230async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1231 response.send(proto::Ack {})?;
1232 Ok(())
1233}
1234
1235/// Creates a new room for calling (outside of channels)
1236async fn create_room(
1237 _request: proto::CreateRoom,
1238 response: Response<proto::CreateRoom>,
1239 session: Session,
1240) -> Result<()> {
1241 let livekit_room = nanoid::nanoid!(30);
1242
1243 let live_kit_connection_info = util::maybe!(async {
1244 let live_kit = session.app_state.livekit_client.as_ref();
1245 let live_kit = live_kit?;
1246 let user_id = session.user_id().to_string();
1247
1248 let token = live_kit
1249 .room_token(&livekit_room, &user_id.to_string())
1250 .trace_err()?;
1251
1252 Some(proto::LiveKitConnectionInfo {
1253 server_url: live_kit.url().into(),
1254 token,
1255 can_publish: true,
1256 })
1257 })
1258 .await;
1259
1260 let room = session
1261 .db()
1262 .await
1263 .create_room(session.user_id(), session.connection_id, &livekit_room)
1264 .await?;
1265
1266 response.send(proto::CreateRoomResponse {
1267 room: Some(room.clone()),
1268 live_kit_connection_info,
1269 })?;
1270
1271 update_user_contacts(session.user_id(), &session).await?;
1272 Ok(())
1273}
1274
1275/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1276async fn join_room(
1277 request: proto::JoinRoom,
1278 response: Response<proto::JoinRoom>,
1279 session: Session,
1280) -> Result<()> {
1281 let room_id = RoomId::from_proto(request.id);
1282
1283 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1284
1285 if let Some(channel_id) = channel_id {
1286 return join_channel_internal(channel_id, Box::new(response), session).await;
1287 }
1288
1289 let joined_room = {
1290 let room = session
1291 .db()
1292 .await
1293 .join_room(room_id, session.user_id(), session.connection_id)
1294 .await?;
1295 room_updated(&room.room, &session.peer);
1296 room.into_inner()
1297 };
1298
1299 for connection_id in session
1300 .connection_pool()
1301 .await
1302 .user_connection_ids(session.user_id())
1303 {
1304 session
1305 .peer
1306 .send(
1307 connection_id,
1308 proto::CallCanceled {
1309 room_id: room_id.to_proto(),
1310 },
1311 )
1312 .trace_err();
1313 }
1314
1315 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1316 {
1317 live_kit
1318 .room_token(
1319 &joined_room.room.livekit_room,
1320 &session.user_id().to_string(),
1321 )
1322 .trace_err()
1323 .map(|token| proto::LiveKitConnectionInfo {
1324 server_url: live_kit.url().into(),
1325 token,
1326 can_publish: true,
1327 })
1328 } else {
1329 None
1330 };
1331
1332 response.send(proto::JoinRoomResponse {
1333 room: Some(joined_room.room),
1334 channel_id: None,
1335 live_kit_connection_info,
1336 })?;
1337
1338 update_user_contacts(session.user_id(), &session).await?;
1339 Ok(())
1340}
1341
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins the caller to the room described by `request`. The response lists
/// the room, the re-shared projects (with their collaborators) and the rejoined projects.
/// Afterwards:
/// - the room update is broadcast;
/// - each re-shared project's collaborators are told the peer id changed from the project's
///   `old_connection_id` to the current connection, and are forwarded an `UpdateProject`
///   with the project's worktrees (send errors are passed to `trace_err` / logged by
///   `broadcast`);
/// - rejoined projects are re-streamed to the caller via `notify_rejoined_projects`;
/// - if the room belongs to a channel, `channel_updated` is sent;
/// - the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so they outlive the guard returned by `rejoin_room`,
    // which is consumed by `into_inner` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the re-shared worktree list to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1433
1434fn notify_rejoined_projects(
1435 rejoined_projects: &mut Vec<RejoinedProject>,
1436 session: &Session,
1437) -> Result<()> {
1438 for project in rejoined_projects.iter() {
1439 for collaborator in &project.collaborators {
1440 session
1441 .peer
1442 .send(
1443 collaborator.connection_id,
1444 proto::UpdateProjectCollaborator {
1445 project_id: project.id.to_proto(),
1446 old_peer_id: Some(project.old_connection_id.into()),
1447 new_peer_id: Some(session.connection_id.into()),
1448 },
1449 )
1450 .trace_err();
1451 }
1452 }
1453
1454 for project in rejoined_projects {
1455 for worktree in mem::take(&mut project.worktrees) {
1456 // Stream this worktree's entries.
1457 let message = proto::UpdateWorktree {
1458 project_id: project.id.to_proto(),
1459 worktree_id: worktree.id,
1460 abs_path: worktree.abs_path.clone(),
1461 root_name: worktree.root_name,
1462 updated_entries: worktree.updated_entries,
1463 removed_entries: worktree.removed_entries,
1464 scan_id: worktree.scan_id,
1465 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1466 updated_repositories: worktree.updated_repositories,
1467 removed_repositories: worktree.removed_repositories,
1468 };
1469 for update in proto::split_worktree_update(message) {
1470 session.peer.send(session.connection_id, update)?;
1471 }
1472
1473 // Stream this worktree's diagnostics.
1474 for summary in worktree.diagnostic_summaries {
1475 session.peer.send(
1476 session.connection_id,
1477 proto::UpdateDiagnosticSummary {
1478 project_id: project.id.to_proto(),
1479 worktree_id: worktree.id,
1480 summary: Some(summary),
1481 },
1482 )?;
1483 }
1484
1485 for settings_file in worktree.settings_files {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateWorktreeSettings {
1489 project_id: project.id.to_proto(),
1490 worktree_id: worktree.id,
1491 path: settings_file.path,
1492 content: Some(settings_file.content),
1493 kind: Some(settings_file.kind.to_proto().into()),
1494 },
1495 )?;
1496 }
1497 }
1498
1499 for repository in mem::take(&mut project.updated_repositories) {
1500 for update in split_repository_update(repository) {
1501 session.peer.send(session.connection_id, update)?;
1502 }
1503 }
1504
1505 for id in mem::take(&mut project.removed_repositories) {
1506 session.peer.send(
1507 session.connection_id,
1508 proto::RemoveRepository {
1509 project_id: project.id.to_proto(),
1510 id,
1511 },
1512 )?;
1513 }
1514 }
1515
1516 Ok(())
1517}
1518
1519/// leave room disconnects from the room.
1520async fn leave_room(
1521 _: proto::LeaveRoom,
1522 response: Response<proto::LeaveRoom>,
1523 session: Session,
1524) -> Result<()> {
1525 leave_room_for_session(&session, session.connection_id).await?;
1526 response.send(proto::Ack {})?;
1527 Ok(())
1528}
1529
1530/// Updates the permissions of someone else in the room.
1531async fn set_room_participant_role(
1532 request: proto::SetRoomParticipantRole,
1533 response: Response<proto::SetRoomParticipantRole>,
1534 session: Session,
1535) -> Result<()> {
1536 let user_id = UserId::from_proto(request.user_id);
1537 let role = ChannelRole::from(request.role());
1538
1539 let (livekit_room, can_publish) = {
1540 let room = session
1541 .db()
1542 .await
1543 .set_room_participant_role(
1544 session.user_id(),
1545 RoomId::from_proto(request.room_id),
1546 user_id,
1547 role,
1548 )
1549 .await?;
1550
1551 let livekit_room = room.livekit_room.clone();
1552 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1553 room_updated(&room, &session.peer);
1554 (livekit_room, can_publish)
1555 };
1556
1557 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1558 live_kit
1559 .update_participant(
1560 livekit_room.clone(),
1561 request.user_id.to_string(),
1562 livekit_api::proto::ParticipantPermission {
1563 can_subscribe: true,
1564 can_publish,
1565 can_publish_data: can_publish,
1566 hidden: false,
1567 recorder: false,
1568 },
1569 )
1570 .await
1571 .trace_err();
1572 }
1573
1574 response.send(proto::Ack {})?;
1575 Ok(())
1576}
1577
1578/// Call someone else into the current room
1579async fn call(
1580 request: proto::Call,
1581 response: Response<proto::Call>,
1582 session: Session,
1583) -> Result<()> {
1584 let room_id = RoomId::from_proto(request.room_id);
1585 let calling_user_id = session.user_id();
1586 let calling_connection_id = session.connection_id;
1587 let called_user_id = UserId::from_proto(request.called_user_id);
1588 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1589 if !session
1590 .db()
1591 .await
1592 .has_contact(calling_user_id, called_user_id)
1593 .await?
1594 {
1595 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1596 }
1597
1598 let incoming_call = {
1599 let (room, incoming_call) = &mut *session
1600 .db()
1601 .await
1602 .call(
1603 room_id,
1604 calling_user_id,
1605 calling_connection_id,
1606 called_user_id,
1607 initial_project_id,
1608 )
1609 .await?;
1610 room_updated(room, &session.peer);
1611 mem::take(incoming_call)
1612 };
1613 update_user_contacts(called_user_id, &session).await?;
1614
1615 let mut calls = session
1616 .connection_pool()
1617 .await
1618 .user_connection_ids(called_user_id)
1619 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1620 .collect::<FuturesUnordered<_>>();
1621
1622 while let Some(call_response) = calls.next().await {
1623 match call_response.as_ref() {
1624 Ok(_) => {
1625 response.send(proto::Ack {})?;
1626 return Ok(());
1627 }
1628 Err(_) => {
1629 call_response.trace_err();
1630 }
1631 }
1632 }
1633
1634 {
1635 let room = session
1636 .db()
1637 .await
1638 .call_failed(room_id, called_user_id)
1639 .await?;
1640 room_updated(&room, &session.peer);
1641 }
1642 update_user_contacts(called_user_id, &session).await?;
1643
1644 Err(anyhow!("failed to ring user"))?
1645}
1646
1647/// Cancel an outgoing call.
1648async fn cancel_call(
1649 request: proto::CancelCall,
1650 response: Response<proto::CancelCall>,
1651 session: Session,
1652) -> Result<()> {
1653 let called_user_id = UserId::from_proto(request.called_user_id);
1654 let room_id = RoomId::from_proto(request.room_id);
1655 {
1656 let room = session
1657 .db()
1658 .await
1659 .cancel_call(room_id, session.connection_id, called_user_id)
1660 .await?;
1661 room_updated(&room, &session.peer);
1662 }
1663
1664 for connection_id in session
1665 .connection_pool()
1666 .await
1667 .user_connection_ids(called_user_id)
1668 {
1669 session
1670 .peer
1671 .send(
1672 connection_id,
1673 proto::CallCanceled {
1674 room_id: room_id.to_proto(),
1675 },
1676 )
1677 .trace_err();
1678 }
1679 response.send(proto::Ack {})?;
1680
1681 update_user_contacts(called_user_id, &session).await?;
1682 Ok(())
1683}
1684
1685/// Decline an incoming call.
1686async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1687 let room_id = RoomId::from_proto(message.room_id);
1688 {
1689 let room = session
1690 .db()
1691 .await
1692 .decline_call(Some(room_id), session.user_id())
1693 .await?
1694 .context("declining call")?;
1695 room_updated(&room, &session.peer);
1696 }
1697
1698 for connection_id in session
1699 .connection_pool()
1700 .await
1701 .user_connection_ids(session.user_id())
1702 {
1703 session
1704 .peer
1705 .send(
1706 connection_id,
1707 proto::CallCanceled {
1708 room_id: room_id.to_proto(),
1709 },
1710 )
1711 .trace_err();
1712 }
1713 update_user_contacts(session.user_id(), &session).await?;
1714 Ok(())
1715}
1716
1717/// Updates other participants in the room with your current location.
1718async fn update_participant_location(
1719 request: proto::UpdateParticipantLocation,
1720 response: Response<proto::UpdateParticipantLocation>,
1721 session: Session,
1722) -> Result<()> {
1723 let room_id = RoomId::from_proto(request.room_id);
1724 let location = request.location.context("invalid location")?;
1725
1726 let db = session.db().await;
1727 let room = db
1728 .update_room_participant_location(room_id, session.connection_id, location)
1729 .await?;
1730
1731 room_updated(&room, &session.peer);
1732 response.send(proto::Ack {})?;
1733 Ok(())
1734}
1735
1736/// Share a project into the room.
1737async fn share_project(
1738 request: proto::ShareProject,
1739 response: Response<proto::ShareProject>,
1740 session: Session,
1741) -> Result<()> {
1742 let (project_id, room) = &*session
1743 .db()
1744 .await
1745 .share_project(
1746 RoomId::from_proto(request.room_id),
1747 session.connection_id,
1748 &request.worktrees,
1749 request.is_ssh_project,
1750 )
1751 .await?;
1752 response.send(proto::ShareProjectResponse {
1753 project_id: project_id.to_proto(),
1754 })?;
1755 room_updated(room, &session.peer);
1756
1757 Ok(())
1758}
1759
1760/// Unshare a project from the room.
1761async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1762 let project_id = ProjectId::from_proto(message.project_id);
1763 unshare_project_internal(project_id, session.connection_id, &session).await
1764}
1765
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// The database unshares the project and returns three things: whether the project should
/// be deleted, the updated room (if any), and the guest connections.
///
/// Every guest other than `connection_id` is sent `UnshareProject`; send errors are logged
/// by `broadcast`. The room update is then broadcast. If the database asked for it, the
/// project is deleted after the inner block has dropped the returned guard.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // `room_guard` has been dropped at this point.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1803
/// Join someone else's shared project.
///
/// The database adds the caller's connection to the project and returns the project together
/// with the replica id assigned to the caller. The session's database handle is dropped
/// explicitly before `join_project_internal` sends the response and streams the project
/// state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1822
/// Delivers the `JoinProjectResponse` produced by `join_project_internal`.
///
/// NOTE(review): presumably abstracted so response types other than
/// `Response<proto::JoinProject>` can be plugged in. Only that type implements it here.
trait JoinProjectInternalResponse {
    /// Sends `result` and consumes the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Forwards to the inherent `Response::send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1831
1832fn join_project_internal(
1833 response: impl JoinProjectInternalResponse,
1834 session: Session,
1835 project: &mut Project,
1836 replica_id: &ReplicaId,
1837) -> Result<()> {
1838 let collaborators = project
1839 .collaborators
1840 .iter()
1841 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842 .map(|collaborator| collaborator.to_proto())
1843 .collect::<Vec<_>>();
1844 let project_id = project.id;
1845 let guest_user_id = session.user_id();
1846
1847 let worktrees = project
1848 .worktrees
1849 .iter()
1850 .map(|(id, worktree)| proto::WorktreeMetadata {
1851 id: *id,
1852 root_name: worktree.root_name.clone(),
1853 visible: worktree.visible,
1854 abs_path: worktree.abs_path.clone(),
1855 })
1856 .collect::<Vec<_>>();
1857
1858 let add_project_collaborator = proto::AddProjectCollaborator {
1859 project_id: project_id.to_proto(),
1860 collaborator: Some(proto::Collaborator {
1861 peer_id: Some(session.connection_id.into()),
1862 replica_id: replica_id.0 as u32,
1863 user_id: guest_user_id.to_proto(),
1864 is_host: false,
1865 }),
1866 };
1867
1868 for collaborator in &collaborators {
1869 session
1870 .peer
1871 .send(
1872 collaborator.peer_id.unwrap().into(),
1873 add_project_collaborator.clone(),
1874 )
1875 .trace_err();
1876 }
1877
1878 // First, we send the metadata associated with each worktree.
1879 response.send(proto::JoinProjectResponse {
1880 project_id: project.id.0 as u64,
1881 worktrees: worktrees.clone(),
1882 replica_id: replica_id.0 as u32,
1883 collaborators: collaborators.clone(),
1884 language_servers: project.language_servers.clone(),
1885 role: project.role.into(),
1886 })?;
1887
1888 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1889 // Stream this worktree's entries.
1890 let message = proto::UpdateWorktree {
1891 project_id: project_id.to_proto(),
1892 worktree_id,
1893 abs_path: worktree.abs_path.clone(),
1894 root_name: worktree.root_name,
1895 updated_entries: worktree.entries,
1896 removed_entries: Default::default(),
1897 scan_id: worktree.scan_id,
1898 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1899 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1900 removed_repositories: Default::default(),
1901 };
1902 for update in proto::split_worktree_update(message) {
1903 session.peer.send(session.connection_id, update.clone())?;
1904 }
1905
1906 // Stream this worktree's diagnostics.
1907 for summary in worktree.diagnostic_summaries {
1908 session.peer.send(
1909 session.connection_id,
1910 proto::UpdateDiagnosticSummary {
1911 project_id: project_id.to_proto(),
1912 worktree_id: worktree.id,
1913 summary: Some(summary),
1914 },
1915 )?;
1916 }
1917
1918 for settings_file in worktree.settings_files {
1919 session.peer.send(
1920 session.connection_id,
1921 proto::UpdateWorktreeSettings {
1922 project_id: project_id.to_proto(),
1923 worktree_id: worktree.id,
1924 path: settings_file.path,
1925 content: Some(settings_file.content),
1926 kind: Some(settings_file.kind.to_proto() as i32),
1927 },
1928 )?;
1929 }
1930 }
1931
1932 for repository in mem::take(&mut project.repositories) {
1933 for update in split_repository_update(repository) {
1934 session.peer.send(session.connection_id, update)?;
1935 }
1936 }
1937
1938 for language_server in &project.language_servers {
1939 session.peer.send(
1940 session.connection_id,
1941 proto::UpdateLanguageServer {
1942 project_id: project_id.to_proto(),
1943 language_server_id: language_server.id,
1944 variant: Some(
1945 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1946 proto::LspDiskBasedDiagnosticsUpdated {},
1947 ),
1948 ),
1949 },
1950 )?;
1951 }
1952
1953 Ok(())
1954}
1955
1956/// Leave someone elses shared project.
1957async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1958 let sender_id = session.connection_id;
1959 let project_id = ProjectId::from_proto(request.project_id);
1960 let db = session.db().await;
1961
1962 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963 tracing::info!(
1964 %project_id,
1965 "leave project"
1966 );
1967
1968 project_left(project, &session);
1969 if let Some(room) = room {
1970 room_updated(room, &session.peer);
1971 }
1972
1973 Ok(())
1974}
1975
1976/// Updates other participants with changes to the project
1977async fn update_project(
1978 request: proto::UpdateProject,
1979 response: Response<proto::UpdateProject>,
1980 session: Session,
1981) -> Result<()> {
1982 let project_id = ProjectId::from_proto(request.project_id);
1983 let (room, guest_connection_ids) = &*session
1984 .db()
1985 .await
1986 .update_project(project_id, session.connection_id, &request.worktrees)
1987 .await?;
1988 broadcast(
1989 Some(session.connection_id),
1990 guest_connection_ids.iter().copied(),
1991 |connection_id| {
1992 session
1993 .peer
1994 .forward_send(session.connection_id, connection_id, request.clone())
1995 },
1996 );
1997 if let Some(room) = room {
1998 room_updated(room, &session.peer);
1999 }
2000 response.send(proto::Ack {})?;
2001
2002 Ok(())
2003}
2004
2005/// Updates other participants with changes to the worktree
2006async fn update_worktree(
2007 request: proto::UpdateWorktree,
2008 response: Response<proto::UpdateWorktree>,
2009 session: Session,
2010) -> Result<()> {
2011 let guest_connection_ids = session
2012 .db()
2013 .await
2014 .update_worktree(&request, session.connection_id)
2015 .await?;
2016
2017 broadcast(
2018 Some(session.connection_id),
2019 guest_connection_ids.iter().copied(),
2020 |connection_id| {
2021 session
2022 .peer
2023 .forward_send(session.connection_id, connection_id, request.clone())
2024 },
2025 );
2026 response.send(proto::Ack {})?;
2027 Ok(())
2028}
2029
2030async fn update_repository(
2031 request: proto::UpdateRepository,
2032 response: Response<proto::UpdateRepository>,
2033 session: Session,
2034) -> Result<()> {
2035 let guest_connection_ids = session
2036 .db()
2037 .await
2038 .update_repository(&request, session.connection_id)
2039 .await?;
2040
2041 broadcast(
2042 Some(session.connection_id),
2043 guest_connection_ids.iter().copied(),
2044 |connection_id| {
2045 session
2046 .peer
2047 .forward_send(session.connection_id, connection_id, request.clone())
2048 },
2049 );
2050 response.send(proto::Ack {})?;
2051 Ok(())
2052}
2053
2054async fn remove_repository(
2055 request: proto::RemoveRepository,
2056 response: Response<proto::RemoveRepository>,
2057 session: Session,
2058) -> Result<()> {
2059 let guest_connection_ids = session
2060 .db()
2061 .await
2062 .remove_repository(&request, session.connection_id)
2063 .await?;
2064
2065 broadcast(
2066 Some(session.connection_id),
2067 guest_connection_ids.iter().copied(),
2068 |connection_id| {
2069 session
2070 .peer
2071 .forward_send(session.connection_id, connection_id, request.clone())
2072 },
2073 );
2074 response.send(proto::Ack {})?;
2075 Ok(())
2076}
2077
2078/// Updates other participants with changes to the diagnostics
2079async fn update_diagnostic_summary(
2080 message: proto::UpdateDiagnosticSummary,
2081 session: Session,
2082) -> Result<()> {
2083 let guest_connection_ids = session
2084 .db()
2085 .await
2086 .update_diagnostic_summary(&message, session.connection_id)
2087 .await?;
2088
2089 broadcast(
2090 Some(session.connection_id),
2091 guest_connection_ids.iter().copied(),
2092 |connection_id| {
2093 session
2094 .peer
2095 .forward_send(session.connection_id, connection_id, message.clone())
2096 },
2097 );
2098
2099 Ok(())
2100}
2101
2102/// Updates other participants with changes to the worktree settings
2103async fn update_worktree_settings(
2104 message: proto::UpdateWorktreeSettings,
2105 session: Session,
2106) -> Result<()> {
2107 let guest_connection_ids = session
2108 .db()
2109 .await
2110 .update_worktree_settings(&message, session.connection_id)
2111 .await?;
2112
2113 broadcast(
2114 Some(session.connection_id),
2115 guest_connection_ids.iter().copied(),
2116 |connection_id| {
2117 session
2118 .peer
2119 .forward_send(session.connection_id, connection_id, message.clone())
2120 },
2121 );
2122
2123 Ok(())
2124}
2125
2126/// Notify other participants that a language server has started.
2127async fn start_language_server(
2128 request: proto::StartLanguageServer,
2129 session: Session,
2130) -> Result<()> {
2131 let guest_connection_ids = session
2132 .db()
2133 .await
2134 .start_language_server(&request, session.connection_id)
2135 .await?;
2136
2137 broadcast(
2138 Some(session.connection_id),
2139 guest_connection_ids.iter().copied(),
2140 |connection_id| {
2141 session
2142 .peer
2143 .forward_send(session.connection_id, connection_id, request.clone())
2144 },
2145 );
2146 Ok(())
2147}
2148
2149/// Notify other participants that a language server has changed.
2150async fn update_language_server(
2151 request: proto::UpdateLanguageServer,
2152 session: Session,
2153) -> Result<()> {
2154 let project_id = ProjectId::from_proto(request.project_id);
2155 let project_connection_ids = session
2156 .db()
2157 .await
2158 .project_connection_ids(project_id, session.connection_id, true)
2159 .await?;
2160 broadcast(
2161 Some(session.connection_id),
2162 project_connection_ids.iter().copied(),
2163 |connection_id| {
2164 session
2165 .peer
2166 .forward_send(session.connection_id, connection_id, request.clone())
2167 },
2168 );
2169 Ok(())
2170}
2171
2172/// forward a project request to the host. These requests should be read only
2173/// as guests are allowed to send them.
2174async fn forward_read_only_project_request<T>(
2175 request: T,
2176 response: Response<T>,
2177 session: Session,
2178) -> Result<()>
2179where
2180 T: EntityMessage + RequestMessage,
2181{
2182 let project_id = ProjectId::from_proto(request.remote_entity_id());
2183 let host_connection_id = session
2184 .db()
2185 .await
2186 .host_for_read_only_project_request(project_id, session.connection_id)
2187 .await?;
2188 let payload = session
2189 .peer
2190 .forward_request(session.connection_id, host_connection_id, request)
2191 .await?;
2192 response.send(payload)?;
2193 Ok(())
2194}
2195
2196async fn forward_find_search_candidates_request(
2197 request: proto::FindSearchCandidates,
2198 response: Response<proto::FindSearchCandidates>,
2199 session: Session,
2200) -> Result<()> {
2201 let project_id = ProjectId::from_proto(request.remote_entity_id());
2202 let host_connection_id = session
2203 .db()
2204 .await
2205 .host_for_read_only_project_request(project_id, session.connection_id)
2206 .await?;
2207 let payload = session
2208 .peer
2209 .forward_request(session.connection_id, host_connection_id, request)
2210 .await?;
2211 response.send(payload)?;
2212 Ok(())
2213}
2214
2215/// forward a project request to the host. These requests are disallowed
2216/// for guests.
2217async fn forward_mutating_project_request<T>(
2218 request: T,
2219 response: Response<T>,
2220 session: Session,
2221) -> Result<()>
2222where
2223 T: EntityMessage + RequestMessage,
2224{
2225 let project_id = ProjectId::from_proto(request.remote_entity_id());
2226
2227 let host_connection_id = session
2228 .db()
2229 .await
2230 .host_for_mutating_project_request(project_id, session.connection_id)
2231 .await?;
2232 let payload = session
2233 .peer
2234 .forward_request(session.connection_id, host_connection_id, request)
2235 .await?;
2236 response.send(payload)?;
2237 Ok(())
2238}
2239
2240/// Notify other participants that a new buffer has been created
2241async fn create_buffer_for_peer(
2242 request: proto::CreateBufferForPeer,
2243 session: Session,
2244) -> Result<()> {
2245 session
2246 .db()
2247 .await
2248 .check_user_is_project_host(
2249 ProjectId::from_proto(request.project_id),
2250 session.connection_id,
2251 )
2252 .await?;
2253 let peer_id = request.peer_id.context("invalid peer id")?;
2254 session
2255 .peer
2256 .forward_send(session.connection_id, peer_id.into(), request)?;
2257 Ok(())
2258}
2259
2260/// Notify other participants that a buffer has been updated. This is
2261/// allowed for guests as long as the update is limited to selections.
2262async fn update_buffer(
2263 request: proto::UpdateBuffer,
2264 response: Response<proto::UpdateBuffer>,
2265 session: Session,
2266) -> Result<()> {
2267 let project_id = ProjectId::from_proto(request.project_id);
2268 let mut capability = Capability::ReadOnly;
2269
2270 for op in request.operations.iter() {
2271 match op.variant {
2272 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2273 Some(_) => capability = Capability::ReadWrite,
2274 }
2275 }
2276
2277 let host = {
2278 let guard = session
2279 .db()
2280 .await
2281 .connections_for_buffer_update(project_id, session.connection_id, capability)
2282 .await?;
2283
2284 let (host, guests) = &*guard;
2285
2286 broadcast(
2287 Some(session.connection_id),
2288 guests.clone(),
2289 |connection_id| {
2290 session
2291 .peer
2292 .forward_send(session.connection_id, connection_id, request.clone())
2293 },
2294 );
2295
2296 *host
2297 };
2298
2299 if host != session.connection_id {
2300 session
2301 .peer
2302 .forward_request(session.connection_id, host, request.clone())
2303 .await?;
2304 }
2305
2306 response.send(proto::Ack {})?;
2307 Ok(())
2308}
2309
2310async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2311 let project_id = ProjectId::from_proto(message.project_id);
2312
2313 let operation = message.operation.as_ref().context("invalid operation")?;
2314 let capability = match operation.variant.as_ref() {
2315 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2316 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2317 match buffer_op.variant {
2318 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2319 Capability::ReadOnly
2320 }
2321 _ => Capability::ReadWrite,
2322 }
2323 } else {
2324 Capability::ReadWrite
2325 }
2326 }
2327 Some(_) => Capability::ReadWrite,
2328 None => Capability::ReadOnly,
2329 };
2330
2331 let guard = session
2332 .db()
2333 .await
2334 .connections_for_buffer_update(project_id, session.connection_id, capability)
2335 .await?;
2336
2337 let (host, guests) = &*guard;
2338
2339 broadcast(
2340 Some(session.connection_id),
2341 guests.iter().chain([host]).copied(),
2342 |connection_id| {
2343 session
2344 .peer
2345 .forward_send(session.connection_id, connection_id, message.clone())
2346 },
2347 );
2348
2349 Ok(())
2350}
2351
2352/// Notify other participants that a project has been updated.
2353async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2354 request: T,
2355 session: Session,
2356) -> Result<()> {
2357 let project_id = ProjectId::from_proto(request.remote_entity_id());
2358 let project_connection_ids = session
2359 .db()
2360 .await
2361 .project_connection_ids(project_id, session.connection_id, false)
2362 .await?;
2363
2364 broadcast(
2365 Some(session.connection_id),
2366 project_connection_ids.iter().copied(),
2367 |connection_id| {
2368 session
2369 .peer
2370 .forward_send(session.connection_id, connection_id, request.clone())
2371 },
2372 );
2373 Ok(())
2374}
2375
2376/// Start following another user in a call.
2377async fn follow(
2378 request: proto::Follow,
2379 response: Response<proto::Follow>,
2380 session: Session,
2381) -> Result<()> {
2382 let room_id = RoomId::from_proto(request.room_id);
2383 let project_id = request.project_id.map(ProjectId::from_proto);
2384 let leader_id = request.leader_id.context("invalid leader id")?.into();
2385 let follower_id = session.connection_id;
2386
2387 session
2388 .db()
2389 .await
2390 .check_room_participants(room_id, leader_id, session.connection_id)
2391 .await?;
2392
2393 let response_payload = session
2394 .peer
2395 .forward_request(session.connection_id, leader_id, request)
2396 .await?;
2397 response.send(response_payload)?;
2398
2399 if let Some(project_id) = project_id {
2400 let room = session
2401 .db()
2402 .await
2403 .follow(room_id, project_id, leader_id, follower_id)
2404 .await?;
2405 room_updated(&room, &session.peer);
2406 }
2407
2408 Ok(())
2409}
2410
2411/// Stop following another user in a call.
2412async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2413 let room_id = RoomId::from_proto(request.room_id);
2414 let project_id = request.project_id.map(ProjectId::from_proto);
2415 let leader_id = request.leader_id.context("invalid leader id")?.into();
2416 let follower_id = session.connection_id;
2417
2418 session
2419 .db()
2420 .await
2421 .check_room_participants(room_id, leader_id, session.connection_id)
2422 .await?;
2423
2424 session
2425 .peer
2426 .forward_send(session.connection_id, leader_id, request)?;
2427
2428 if let Some(project_id) = project_id {
2429 let room = session
2430 .db()
2431 .await
2432 .unfollow(room_id, project_id, leader_id, follower_id)
2433 .await?;
2434 room_updated(&room, &session.peer);
2435 }
2436
2437 Ok(())
2438}
2439
2440/// Notify everyone following you of your current location.
2441async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2442 let room_id = RoomId::from_proto(request.room_id);
2443 let database = session.db.lock().await;
2444
2445 let connection_ids = if let Some(project_id) = request.project_id {
2446 let project_id = ProjectId::from_proto(project_id);
2447 database
2448 .project_connection_ids(project_id, session.connection_id, true)
2449 .await?
2450 } else {
2451 database
2452 .room_connection_ids(room_id, session.connection_id)
2453 .await?
2454 };
2455
2456 // For now, don't send view update messages back to that view's current leader.
2457 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2458 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2459 _ => None,
2460 });
2461
2462 for connection_id in connection_ids.iter().cloned() {
2463 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2464 session
2465 .peer
2466 .forward_send(session.connection_id, connection_id, request.clone())?;
2467 }
2468 }
2469 Ok(())
2470}
2471
2472/// Get public data about users.
2473async fn get_users(
2474 request: proto::GetUsers,
2475 response: Response<proto::GetUsers>,
2476 session: Session,
2477) -> Result<()> {
2478 let user_ids = request
2479 .user_ids
2480 .into_iter()
2481 .map(UserId::from_proto)
2482 .collect();
2483 let users = session
2484 .db()
2485 .await
2486 .get_users_by_ids(user_ids)
2487 .await?
2488 .into_iter()
2489 .map(|user| proto::User {
2490 id: user.id.to_proto(),
2491 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2492 github_login: user.github_login,
2493 email: user.email_address,
2494 name: user.name,
2495 })
2496 .collect();
2497 response.send(proto::UsersResponse { users })?;
2498 Ok(())
2499}
2500
2501/// Search for users (to invite) buy Github login
2502async fn fuzzy_search_users(
2503 request: proto::FuzzySearchUsers,
2504 response: Response<proto::FuzzySearchUsers>,
2505 session: Session,
2506) -> Result<()> {
2507 let query = request.query;
2508 let users = match query.len() {
2509 0 => vec![],
2510 1 | 2 => session
2511 .db()
2512 .await
2513 .get_user_by_github_login(&query)
2514 .await?
2515 .into_iter()
2516 .collect(),
2517 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2518 };
2519 let users = users
2520 .into_iter()
2521 .filter(|user| user.id != session.user_id())
2522 .map(|user| proto::User {
2523 id: user.id.to_proto(),
2524 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2525 github_login: user.github_login,
2526 name: user.name,
2527 email: user.email_address,
2528 })
2529 .collect();
2530 response.send(proto::UsersResponse { users })?;
2531 Ok(())
2532}
2533
2534/// Send a contact request to another user.
2535async fn request_contact(
2536 request: proto::RequestContact,
2537 response: Response<proto::RequestContact>,
2538 session: Session,
2539) -> Result<()> {
2540 let requester_id = session.user_id();
2541 let responder_id = UserId::from_proto(request.responder_id);
2542 if requester_id == responder_id {
2543 return Err(anyhow!("cannot add yourself as a contact"))?;
2544 }
2545
2546 let notifications = session
2547 .db()
2548 .await
2549 .send_contact_request(requester_id, responder_id)
2550 .await?;
2551
2552 // Update outgoing contact requests of requester
2553 let mut update = proto::UpdateContacts::default();
2554 update.outgoing_requests.push(responder_id.to_proto());
2555 for connection_id in session
2556 .connection_pool()
2557 .await
2558 .user_connection_ids(requester_id)
2559 {
2560 session.peer.send(connection_id, update.clone())?;
2561 }
2562
2563 // Update incoming contact requests of responder
2564 let mut update = proto::UpdateContacts::default();
2565 update
2566 .incoming_requests
2567 .push(proto::IncomingContactRequest {
2568 requester_id: requester_id.to_proto(),
2569 });
2570 let connection_pool = session.connection_pool().await;
2571 for connection_id in connection_pool.user_connection_ids(responder_id) {
2572 session.peer.send(connection_id, update.clone())?;
2573 }
2574
2575 send_notifications(&connection_pool, &session.peer, notifications);
2576
2577 response.send(proto::Ack {})?;
2578 Ok(())
2579}
2580
2581/// Accept or decline a contact request
2582async fn respond_to_contact_request(
2583 request: proto::RespondToContactRequest,
2584 response: Response<proto::RespondToContactRequest>,
2585 session: Session,
2586) -> Result<()> {
2587 let responder_id = session.user_id();
2588 let requester_id = UserId::from_proto(request.requester_id);
2589 let db = session.db().await;
2590 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2591 db.dismiss_contact_notification(responder_id, requester_id)
2592 .await?;
2593 } else {
2594 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2595
2596 let notifications = db
2597 .respond_to_contact_request(responder_id, requester_id, accept)
2598 .await?;
2599 let requester_busy = db.is_user_busy(requester_id).await?;
2600 let responder_busy = db.is_user_busy(responder_id).await?;
2601
2602 let pool = session.connection_pool().await;
2603 // Update responder with new contact
2604 let mut update = proto::UpdateContacts::default();
2605 if accept {
2606 update
2607 .contacts
2608 .push(contact_for_user(requester_id, requester_busy, &pool));
2609 }
2610 update
2611 .remove_incoming_requests
2612 .push(requester_id.to_proto());
2613 for connection_id in pool.user_connection_ids(responder_id) {
2614 session.peer.send(connection_id, update.clone())?;
2615 }
2616
2617 // Update requester with new contact
2618 let mut update = proto::UpdateContacts::default();
2619 if accept {
2620 update
2621 .contacts
2622 .push(contact_for_user(responder_id, responder_busy, &pool));
2623 }
2624 update
2625 .remove_outgoing_requests
2626 .push(responder_id.to_proto());
2627
2628 for connection_id in pool.user_connection_ids(requester_id) {
2629 session.peer.send(connection_id, update.clone())?;
2630 }
2631
2632 send_notifications(&pool, &session.peer, notifications);
2633 }
2634
2635 response.send(proto::Ack {})?;
2636 Ok(())
2637}
2638
2639/// Remove a contact.
2640async fn remove_contact(
2641 request: proto::RemoveContact,
2642 response: Response<proto::RemoveContact>,
2643 session: Session,
2644) -> Result<()> {
2645 let requester_id = session.user_id();
2646 let responder_id = UserId::from_proto(request.user_id);
2647 let db = session.db().await;
2648 let (contact_accepted, deleted_notification_id) =
2649 db.remove_contact(requester_id, responder_id).await?;
2650
2651 let pool = session.connection_pool().await;
2652 // Update outgoing contact requests of requester
2653 let mut update = proto::UpdateContacts::default();
2654 if contact_accepted {
2655 update.remove_contacts.push(responder_id.to_proto());
2656 } else {
2657 update
2658 .remove_outgoing_requests
2659 .push(responder_id.to_proto());
2660 }
2661 for connection_id in pool.user_connection_ids(requester_id) {
2662 session.peer.send(connection_id, update.clone())?;
2663 }
2664
2665 // Update incoming contact requests of responder
2666 let mut update = proto::UpdateContacts::default();
2667 if contact_accepted {
2668 update.remove_contacts.push(requester_id.to_proto());
2669 } else {
2670 update
2671 .remove_incoming_requests
2672 .push(requester_id.to_proto());
2673 }
2674 for connection_id in pool.user_connection_ids(responder_id) {
2675 session.peer.send(connection_id, update.clone())?;
2676 if let Some(notification_id) = deleted_notification_id {
2677 session.peer.send(
2678 connection_id,
2679 proto::DeleteNotification {
2680 notification_id: notification_id.to_proto(),
2681 },
2682 )?;
2683 }
2684 }
2685
2686 response.send(proto::Ack {})?;
2687 Ok(())
2688}
2689
2690fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2691 version.0.minor() < 139
2692}
2693
2694async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2695 if is_staff {
2696 return Ok(proto::Plan::ZedPro);
2697 }
2698
2699 let subscription = db.get_active_billing_subscription(user_id).await?;
2700 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2701
2702 let plan = if let Some(subscription_kind) = subscription_kind {
2703 match subscription_kind {
2704 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2705 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2706 SubscriptionKind::ZedFree => proto::Plan::Free,
2707 }
2708 } else {
2709 proto::Plan::Free
2710 };
2711
2712 Ok(plan)
2713}
2714
/// Builds the `UpdateUserPlan` message for `user`. It covers the resolved
/// plan, trial start time, usage-based-billing preference, current
/// subscription period, account-age flag, overdue-invoice flag, and usage.
///
/// * `is_staff` — staff resolve to `ZedPro` (via `current_plan`) and always
///   have usage-based billing reported as enabled.
/// * `llm_db` — when `None`, `subscription_period` and `usage` are left unset.
///
/// # Errors
/// Propagates any error from the `db` / `llm_db` lookups.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // The subscription period and usage are only computed when the LLM
    // database is available; otherwise both are reported as absent.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is looked up for the current period only; with no period there is no usage.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Only `ZedPro` is exempt from the minimum-account-age check. Trial and
    // free users with young accounts are flagged.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // NOTE(review): `as u64` wraps a (pre-1970) negative timestamp — assumed not to occur.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report usage-based billing as enabled; others follow
        // their stored preference (absent if they have none).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Mirror the proto plan into the `zed_llm_client` plan to look up its limits.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Trial users with the extended-trial feature flag get a limit of
            // 1,000 model requests in place of the plan's default limit.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below silently truncate/wrap if
            // counts or limits fall outside `u32` — confirm the source ranges.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                // Edit predictions use the plan's limit unmodified.
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2823
2824async fn update_user_plan(session: &Session) -> Result<()> {
2825 let db = session.db().await;
2826
2827 let update_user_plan = make_update_user_plan_message(
2828 session.principal.user(),
2829 session.is_staff(),
2830 &db.0,
2831 session.app_state.llm_db.clone(),
2832 )
2833 .await?;
2834
2835 session
2836 .peer
2837 .send(session.connection_id, update_user_plan)
2838 .trace_err();
2839
2840 Ok(())
2841}
2842
2843async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2844 subscribe_user_to_channels(session.user_id(), &session).await?;
2845 Ok(())
2846}
2847
2848async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2849 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2850 let mut pool = session.connection_pool().await;
2851 for membership in &channels_for_user.channel_memberships {
2852 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2853 }
2854 session.peer.send(
2855 session.connection_id,
2856 build_update_user_channels(&channels_for_user),
2857 )?;
2858 session.peer.send(
2859 session.connection_id,
2860 build_channels_update(channels_for_user),
2861 )?;
2862 Ok(())
2863}
2864
2865/// Creates a new channel.
2866async fn create_channel(
2867 request: proto::CreateChannel,
2868 response: Response<proto::CreateChannel>,
2869 session: Session,
2870) -> Result<()> {
2871 let db = session.db().await;
2872
2873 let parent_id = request.parent_id.map(ChannelId::from_proto);
2874 let (channel, membership) = db
2875 .create_channel(&request.name, parent_id, session.user_id())
2876 .await?;
2877
2878 let root_id = channel.root_id();
2879 let channel = Channel::from_model(channel);
2880
2881 response.send(proto::CreateChannelResponse {
2882 channel: Some(channel.to_proto()),
2883 parent_id: request.parent_id,
2884 })?;
2885
2886 let mut connection_pool = session.connection_pool().await;
2887 if let Some(membership) = membership {
2888 connection_pool.subscribe_to_channel(
2889 membership.user_id,
2890 membership.channel_id,
2891 membership.role,
2892 );
2893 let update = proto::UpdateUserChannels {
2894 channel_memberships: vec![proto::ChannelMembership {
2895 channel_id: membership.channel_id.to_proto(),
2896 role: membership.role.into(),
2897 }],
2898 ..Default::default()
2899 };
2900 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2901 session.peer.send(connection_id, update.clone())?;
2902 }
2903 }
2904
2905 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2906 if !role.can_see_channel(channel.visibility) {
2907 continue;
2908 }
2909
2910 let update = proto::UpdateChannels {
2911 channels: vec![channel.to_proto()],
2912 ..Default::default()
2913 };
2914 session.peer.send(connection_id, update.clone())?;
2915 }
2916
2917 Ok(())
2918}
2919
2920/// Delete a channel
2921async fn delete_channel(
2922 request: proto::DeleteChannel,
2923 response: Response<proto::DeleteChannel>,
2924 session: Session,
2925) -> Result<()> {
2926 let db = session.db().await;
2927
2928 let channel_id = request.channel_id;
2929 let (root_channel, removed_channels) = db
2930 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2931 .await?;
2932 response.send(proto::Ack {})?;
2933
2934 // Notify members of removed channels
2935 let mut update = proto::UpdateChannels::default();
2936 update
2937 .delete_channels
2938 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2939
2940 let connection_pool = session.connection_pool().await;
2941 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2942 session.peer.send(connection_id, update.clone())?;
2943 }
2944
2945 Ok(())
2946}
2947
2948/// Invite someone to join a channel.
2949async fn invite_channel_member(
2950 request: proto::InviteChannelMember,
2951 response: Response<proto::InviteChannelMember>,
2952 session: Session,
2953) -> Result<()> {
2954 let db = session.db().await;
2955 let channel_id = ChannelId::from_proto(request.channel_id);
2956 let invitee_id = UserId::from_proto(request.user_id);
2957 let InviteMemberResult {
2958 channel,
2959 notifications,
2960 } = db
2961 .invite_channel_member(
2962 channel_id,
2963 invitee_id,
2964 session.user_id(),
2965 request.role().into(),
2966 )
2967 .await?;
2968
2969 let update = proto::UpdateChannels {
2970 channel_invitations: vec![channel.to_proto()],
2971 ..Default::default()
2972 };
2973
2974 let connection_pool = session.connection_pool().await;
2975 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2976 session.peer.send(connection_id, update.clone())?;
2977 }
2978
2979 send_notifications(&connection_pool, &session.peer, notifications);
2980
2981 response.send(proto::Ack {})?;
2982 Ok(())
2983}
2984
2985/// remove someone from a channel
2986async fn remove_channel_member(
2987 request: proto::RemoveChannelMember,
2988 response: Response<proto::RemoveChannelMember>,
2989 session: Session,
2990) -> Result<()> {
2991 let db = session.db().await;
2992 let channel_id = ChannelId::from_proto(request.channel_id);
2993 let member_id = UserId::from_proto(request.user_id);
2994
2995 let RemoveChannelMemberResult {
2996 membership_update,
2997 notification_id,
2998 } = db
2999 .remove_channel_member(channel_id, member_id, session.user_id())
3000 .await?;
3001
3002 let mut connection_pool = session.connection_pool().await;
3003 notify_membership_updated(
3004 &mut connection_pool,
3005 membership_update,
3006 member_id,
3007 &session.peer,
3008 );
3009 for connection_id in connection_pool.user_connection_ids(member_id) {
3010 if let Some(notification_id) = notification_id {
3011 session
3012 .peer
3013 .send(
3014 connection_id,
3015 proto::DeleteNotification {
3016 notification_id: notification_id.to_proto(),
3017 },
3018 )
3019 .trace_err();
3020 }
3021 }
3022
3023 response.send(proto::Ack {})?;
3024 Ok(())
3025}
3026
3027/// Toggle the channel between public and private.
3028/// Care is taken to maintain the invariant that public channels only descend from public channels,
3029/// (though members-only channels can appear at any point in the hierarchy).
3030async fn set_channel_visibility(
3031 request: proto::SetChannelVisibility,
3032 response: Response<proto::SetChannelVisibility>,
3033 session: Session,
3034) -> Result<()> {
3035 let db = session.db().await;
3036 let channel_id = ChannelId::from_proto(request.channel_id);
3037 let visibility = request.visibility().into();
3038
3039 let channel_model = db
3040 .set_channel_visibility(channel_id, visibility, session.user_id())
3041 .await?;
3042 let root_id = channel_model.root_id();
3043 let channel = Channel::from_model(channel_model);
3044
3045 let mut connection_pool = session.connection_pool().await;
3046 for (user_id, role) in connection_pool
3047 .channel_user_ids(root_id)
3048 .collect::<Vec<_>>()
3049 .into_iter()
3050 {
3051 let update = if role.can_see_channel(channel.visibility) {
3052 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3053 proto::UpdateChannels {
3054 channels: vec![channel.to_proto()],
3055 ..Default::default()
3056 }
3057 } else {
3058 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3059 proto::UpdateChannels {
3060 delete_channels: vec![channel.id.to_proto()],
3061 ..Default::default()
3062 }
3063 };
3064
3065 for connection_id in connection_pool.user_connection_ids(user_id) {
3066 session.peer.send(connection_id, update.clone())?;
3067 }
3068 }
3069
3070 response.send(proto::Ack {})?;
3071 Ok(())
3072}
3073
3074/// Alter the role for a user in the channel.
3075async fn set_channel_member_role(
3076 request: proto::SetChannelMemberRole,
3077 response: Response<proto::SetChannelMemberRole>,
3078 session: Session,
3079) -> Result<()> {
3080 let db = session.db().await;
3081 let channel_id = ChannelId::from_proto(request.channel_id);
3082 let member_id = UserId::from_proto(request.user_id);
3083 let result = db
3084 .set_channel_member_role(
3085 channel_id,
3086 session.user_id(),
3087 member_id,
3088 request.role().into(),
3089 )
3090 .await?;
3091
3092 match result {
3093 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3094 let mut connection_pool = session.connection_pool().await;
3095 notify_membership_updated(
3096 &mut connection_pool,
3097 membership_update,
3098 member_id,
3099 &session.peer,
3100 )
3101 }
3102 db::SetMemberRoleResult::InviteUpdated(channel) => {
3103 let update = proto::UpdateChannels {
3104 channel_invitations: vec![channel.to_proto()],
3105 ..Default::default()
3106 };
3107
3108 for connection_id in session
3109 .connection_pool()
3110 .await
3111 .user_connection_ids(member_id)
3112 {
3113 session.peer.send(connection_id, update.clone())?;
3114 }
3115 }
3116 }
3117
3118 response.send(proto::Ack {})?;
3119 Ok(())
3120}
3121
3122/// Change the name of a channel
3123async fn rename_channel(
3124 request: proto::RenameChannel,
3125 response: Response<proto::RenameChannel>,
3126 session: Session,
3127) -> Result<()> {
3128 let db = session.db().await;
3129 let channel_id = ChannelId::from_proto(request.channel_id);
3130 let channel_model = db
3131 .rename_channel(channel_id, session.user_id(), &request.name)
3132 .await?;
3133 let root_id = channel_model.root_id();
3134 let channel = Channel::from_model(channel_model);
3135
3136 response.send(proto::RenameChannelResponse {
3137 channel: Some(channel.to_proto()),
3138 })?;
3139
3140 let connection_pool = session.connection_pool().await;
3141 let update = proto::UpdateChannels {
3142 channels: vec![channel.to_proto()],
3143 ..Default::default()
3144 };
3145 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3146 if role.can_see_channel(channel.visibility) {
3147 session.peer.send(connection_id, update.clone())?;
3148 }
3149 }
3150
3151 Ok(())
3152}
3153
3154/// Move a channel to a new parent.
3155async fn move_channel(
3156 request: proto::MoveChannel,
3157 response: Response<proto::MoveChannel>,
3158 session: Session,
3159) -> Result<()> {
3160 let channel_id = ChannelId::from_proto(request.channel_id);
3161 let to = ChannelId::from_proto(request.to);
3162
3163 let (root_id, channels) = session
3164 .db()
3165 .await
3166 .move_channel(channel_id, to, session.user_id())
3167 .await?;
3168
3169 let connection_pool = session.connection_pool().await;
3170 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3171 let channels = channels
3172 .iter()
3173 .filter_map(|channel| {
3174 if role.can_see_channel(channel.visibility) {
3175 Some(channel.to_proto())
3176 } else {
3177 None
3178 }
3179 })
3180 .collect::<Vec<_>>();
3181 if channels.is_empty() {
3182 continue;
3183 }
3184
3185 let update = proto::UpdateChannels {
3186 channels,
3187 ..Default::default()
3188 };
3189
3190 session.peer.send(connection_id, update.clone())?;
3191 }
3192
3193 response.send(Ack {})?;
3194 Ok(())
3195}
3196
3197/// Get the list of channel members
3198async fn get_channel_members(
3199 request: proto::GetChannelMembers,
3200 response: Response<proto::GetChannelMembers>,
3201 session: Session,
3202) -> Result<()> {
3203 let db = session.db().await;
3204 let channel_id = ChannelId::from_proto(request.channel_id);
3205 let limit = if request.limit == 0 {
3206 u16::MAX as u64
3207 } else {
3208 request.limit
3209 };
3210 let (members, users) = db
3211 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3212 .await?;
3213 response.send(proto::GetChannelMembersResponse { members, users })?;
3214 Ok(())
3215}
3216
3217/// Accept or decline a channel invitation.
3218async fn respond_to_channel_invite(
3219 request: proto::RespondToChannelInvite,
3220 response: Response<proto::RespondToChannelInvite>,
3221 session: Session,
3222) -> Result<()> {
3223 let db = session.db().await;
3224 let channel_id = ChannelId::from_proto(request.channel_id);
3225 let RespondToChannelInvite {
3226 membership_update,
3227 notifications,
3228 } = db
3229 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3230 .await?;
3231
3232 let mut connection_pool = session.connection_pool().await;
3233 if let Some(membership_update) = membership_update {
3234 notify_membership_updated(
3235 &mut connection_pool,
3236 membership_update,
3237 session.user_id(),
3238 &session.peer,
3239 );
3240 } else {
3241 let update = proto::UpdateChannels {
3242 remove_channel_invitations: vec![channel_id.to_proto()],
3243 ..Default::default()
3244 };
3245
3246 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3247 session.peer.send(connection_id, update.clone())?;
3248 }
3249 };
3250
3251 send_notifications(&connection_pool, &session.peer, notifications);
3252
3253 response.send(proto::Ack {})?;
3254
3255 Ok(())
3256}
3257
3258/// Join the channels' room
3259async fn join_channel(
3260 request: proto::JoinChannel,
3261 response: Response<proto::JoinChannel>,
3262 session: Session,
3263) -> Result<()> {
3264 let channel_id = ChannelId::from_proto(request.channel_id);
3265 join_channel_internal(channel_id, Box::new(response), session).await
3266}
3267
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// It is implemented for both `JoinChannel` and `JoinRoom` responses, so
/// `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The impls call `Response::<…>::send` fully qualified. Presumably this makes it
// unambiguous that they delegate to `Response`'s own `send` rather than to this
// trait's method of the same name.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3281
/// Join the room of `channel_id` on behalf of the session's user and answer
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The db handle is released around that call and
///    re-acquired afterwards.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint credentials. Guests get a guest
///    token and `can_publish = false`; everyone else gets a room token and
///    `can_publish = true`. A token error is traced, and the response then
///    carries no LiveKit connection info.
/// 4. Reply to the requester, apply any membership update to the connection
///    pool, and send the new room state to the room's participants.
/// 5. After the db/pool scope ends, broadcast the channel's participant list
///    and refresh the user's contacts.
///
/// # Errors
///
/// Fails if any database call fails, if the response cannot be sent, or if the
/// joined room has no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped block: the db guard and the connection-pool guard taken inside are
    // both dropped before `session.connection_pool()` is awaited again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the stale room. Presumably
            // `leave_room_for_session` acquires it itself; confirm before
            // reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `?` inside this closure returns `None` from the closure, so a failed
        // token mint just omits LiveKit info instead of failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3375
3376/// Start editing the channel notes
3377async fn join_channel_buffer(
3378 request: proto::JoinChannelBuffer,
3379 response: Response<proto::JoinChannelBuffer>,
3380 session: Session,
3381) -> Result<()> {
3382 let db = session.db().await;
3383 let channel_id = ChannelId::from_proto(request.channel_id);
3384
3385 let open_response = db
3386 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3387 .await?;
3388
3389 let collaborators = open_response.collaborators.clone();
3390 response.send(open_response)?;
3391
3392 let update = UpdateChannelBufferCollaborators {
3393 channel_id: channel_id.to_proto(),
3394 collaborators: collaborators.clone(),
3395 };
3396 channel_buffer_updated(
3397 session.connection_id,
3398 collaborators
3399 .iter()
3400 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3401 &update,
3402 &session.peer,
3403 );
3404
3405 Ok(())
3406}
3407
3408/// Edit the channel notes
3409async fn update_channel_buffer(
3410 request: proto::UpdateChannelBuffer,
3411 session: Session,
3412) -> Result<()> {
3413 let db = session.db().await;
3414 let channel_id = ChannelId::from_proto(request.channel_id);
3415
3416 let (collaborators, epoch, version) = db
3417 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3418 .await?;
3419
3420 channel_buffer_updated(
3421 session.connection_id,
3422 collaborators.clone(),
3423 &proto::UpdateChannelBuffer {
3424 channel_id: channel_id.to_proto(),
3425 operations: request.operations,
3426 },
3427 &session.peer,
3428 );
3429
3430 let pool = &*session.connection_pool().await;
3431
3432 let non_collaborators =
3433 pool.channel_connection_ids(channel_id)
3434 .filter_map(|(connection_id, _)| {
3435 if collaborators.contains(&connection_id) {
3436 None
3437 } else {
3438 Some(connection_id)
3439 }
3440 });
3441
3442 broadcast(None, non_collaborators, |peer_id| {
3443 session.peer.send(
3444 peer_id,
3445 proto::UpdateChannels {
3446 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3447 channel_id: channel_id.to_proto(),
3448 epoch: epoch as u64,
3449 version: version.clone(),
3450 }],
3451 ..Default::default()
3452 },
3453 )
3454 });
3455
3456 Ok(())
3457}
3458
3459/// Rejoin the channel notes after a connection blip
3460async fn rejoin_channel_buffers(
3461 request: proto::RejoinChannelBuffers,
3462 response: Response<proto::RejoinChannelBuffers>,
3463 session: Session,
3464) -> Result<()> {
3465 let db = session.db().await;
3466 let buffers = db
3467 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3468 .await?;
3469
3470 for rejoined_buffer in &buffers {
3471 let collaborators_to_notify = rejoined_buffer
3472 .buffer
3473 .collaborators
3474 .iter()
3475 .filter_map(|c| Some(c.peer_id?.into()));
3476 channel_buffer_updated(
3477 session.connection_id,
3478 collaborators_to_notify,
3479 &proto::UpdateChannelBufferCollaborators {
3480 channel_id: rejoined_buffer.buffer.channel_id,
3481 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3482 },
3483 &session.peer,
3484 );
3485 }
3486
3487 response.send(proto::RejoinChannelBuffersResponse {
3488 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3489 })?;
3490
3491 Ok(())
3492}
3493
3494/// Stop editing the channel notes
3495async fn leave_channel_buffer(
3496 request: proto::LeaveChannelBuffer,
3497 response: Response<proto::LeaveChannelBuffer>,
3498 session: Session,
3499) -> Result<()> {
3500 let db = session.db().await;
3501 let channel_id = ChannelId::from_proto(request.channel_id);
3502
3503 let left_buffer = db
3504 .leave_channel_buffer(channel_id, session.connection_id)
3505 .await?;
3506
3507 response.send(Ack {})?;
3508
3509 channel_buffer_updated(
3510 session.connection_id,
3511 left_buffer.connections,
3512 &proto::UpdateChannelBufferCollaborators {
3513 channel_id: channel_id.to_proto(),
3514 collaborators: left_buffer.collaborators,
3515 },
3516 &session.peer,
3517 );
3518
3519 Ok(())
3520}
3521
3522fn channel_buffer_updated<T: EnvelopedMessage>(
3523 sender_id: ConnectionId,
3524 collaborators: impl IntoIterator<Item = ConnectionId>,
3525 message: &T,
3526 peer: &Peer,
3527) {
3528 broadcast(Some(sender_id), collaborators, |peer_id| {
3529 peer.send(peer_id, message.clone())
3530 });
3531}
3532
3533fn send_notifications(
3534 connection_pool: &ConnectionPool,
3535 peer: &Peer,
3536 notifications: db::NotificationBatch,
3537) {
3538 for (user_id, notification) in notifications {
3539 for connection_id in connection_pool.user_connection_ids(user_id) {
3540 if let Err(error) = peer.send(
3541 connection_id,
3542 proto::AddNotification {
3543 notification: Some(notification.clone()),
3544 },
3545 ) {
3546 tracing::error!(
3547 "failed to send notification to {:?} {}",
3548 connection_id,
3549 error
3550 );
3551 }
3552 }
3553 }
3554}
3555
/// Post a new chat message to a channel on behalf of the current user.
///
/// The body is trimmed and must be non-empty and at most `MAX_MESSAGE_LEN`
/// bytes. A nonce is required. After the message is stored:
/// 1. it is sent to the other chat participants as `ChannelMessageSent`;
/// 2. the requester receives the stored message;
/// 3. every other connection subscribed to the channel is told the latest
///    message id;
/// 4. notifications returned by the database are delivered.
///
/// # Errors
///
/// Fails on a too-long or blank body, a missing nonce, a database error, or
/// when the response cannot be sent.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` presumably holds offsets into the
    // untrimmed body, so trimming leading whitespace would shift them. This is
    // not handled yet.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only learn
    // the latest message id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3648
3649/// Delete a channel message
3650async fn remove_channel_message(
3651 request: proto::RemoveChannelMessage,
3652 response: Response<proto::RemoveChannelMessage>,
3653 session: Session,
3654) -> Result<()> {
3655 let channel_id = ChannelId::from_proto(request.channel_id);
3656 let message_id = MessageId::from_proto(request.message_id);
3657 let (connection_ids, existing_notification_ids) = session
3658 .db()
3659 .await
3660 .remove_channel_message(channel_id, message_id, session.user_id())
3661 .await?;
3662
3663 broadcast(
3664 Some(session.connection_id),
3665 connection_ids,
3666 move |connection| {
3667 session.peer.send(connection, request.clone())?;
3668
3669 for notification_id in &existing_notification_ids {
3670 session.peer.send(
3671 connection,
3672 proto::DeleteNotification {
3673 notification_id: (*notification_id).to_proto(),
3674 },
3675 )?;
3676 }
3677
3678 Ok(())
3679 },
3680 );
3681 response.send(proto::Ack {})?;
3682 Ok(())
3683}
3684
3685async fn update_channel_message(
3686 request: proto::UpdateChannelMessage,
3687 response: Response<proto::UpdateChannelMessage>,
3688 session: Session,
3689) -> Result<()> {
3690 let channel_id = ChannelId::from_proto(request.channel_id);
3691 let message_id = MessageId::from_proto(request.message_id);
3692 let updated_at = OffsetDateTime::now_utc();
3693 let UpdatedChannelMessage {
3694 message_id,
3695 participant_connection_ids,
3696 notifications,
3697 reply_to_message_id,
3698 timestamp,
3699 deleted_mention_notification_ids,
3700 updated_mention_notifications,
3701 } = session
3702 .db()
3703 .await
3704 .update_channel_message(
3705 channel_id,
3706 message_id,
3707 session.user_id(),
3708 request.body.as_str(),
3709 &request.mentions,
3710 updated_at,
3711 )
3712 .await?;
3713
3714 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3715
3716 let message = proto::ChannelMessage {
3717 sender_id: session.user_id().to_proto(),
3718 id: message_id.to_proto(),
3719 body: request.body.clone(),
3720 mentions: request.mentions.clone(),
3721 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3722 nonce: Some(nonce),
3723 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3724 edited_at: Some(updated_at.unix_timestamp() as u64),
3725 };
3726
3727 response.send(proto::Ack {})?;
3728
3729 let pool = &*session.connection_pool().await;
3730 broadcast(
3731 Some(session.connection_id),
3732 participant_connection_ids,
3733 |connection| {
3734 session.peer.send(
3735 connection,
3736 proto::ChannelMessageUpdate {
3737 channel_id: channel_id.to_proto(),
3738 message: Some(message.clone()),
3739 },
3740 )?;
3741
3742 for notification_id in &deleted_mention_notification_ids {
3743 session.peer.send(
3744 connection,
3745 proto::DeleteNotification {
3746 notification_id: (*notification_id).to_proto(),
3747 },
3748 )?;
3749 }
3750
3751 for notification in &updated_mention_notifications {
3752 session.peer.send(
3753 connection,
3754 proto::UpdateNotification {
3755 notification: Some(notification.clone()),
3756 },
3757 )?;
3758 }
3759
3760 Ok(())
3761 },
3762 );
3763
3764 send_notifications(pool, &session.peer, notifications);
3765
3766 Ok(())
3767}
3768
3769/// Mark a channel message as read
3770async fn acknowledge_channel_message(
3771 request: proto::AckChannelMessage,
3772 session: Session,
3773) -> Result<()> {
3774 let channel_id = ChannelId::from_proto(request.channel_id);
3775 let message_id = MessageId::from_proto(request.message_id);
3776 let notifications = session
3777 .db()
3778 .await
3779 .observe_channel_message(channel_id, session.user_id(), message_id)
3780 .await?;
3781 send_notifications(
3782 &*session.connection_pool().await,
3783 &session.peer,
3784 notifications,
3785 );
3786 Ok(())
3787}
3788
3789/// Mark a buffer version as synced
3790async fn acknowledge_buffer_version(
3791 request: proto::AckBufferOperation,
3792 session: Session,
3793) -> Result<()> {
3794 let buffer_id = BufferId::from_proto(request.buffer_id);
3795 session
3796 .db()
3797 .await
3798 .observe_buffer_version(
3799 buffer_id,
3800 session.user_id(),
3801 request.epoch as i32,
3802 &request.version,
3803 )
3804 .await?;
3805 Ok(())
3806}
3807
3808/// Get a Supermaven API key for the user
3809async fn get_supermaven_api_key(
3810 _request: proto::GetSupermavenApiKey,
3811 response: Response<proto::GetSupermavenApiKey>,
3812 session: Session,
3813) -> Result<()> {
3814 let user_id: String = session.user_id().to_string();
3815 if !session.is_staff() {
3816 return Err(anyhow!("supermaven not enabled for this account"))?;
3817 }
3818
3819 let email = session.email().context("user must have an email")?;
3820
3821 let supermaven_admin_api = session
3822 .supermaven_client
3823 .as_ref()
3824 .context("supermaven not configured")?;
3825
3826 let result = supermaven_admin_api
3827 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3828 .await?;
3829
3830 response.send(proto::GetSupermavenApiKeyResponse {
3831 api_key: result.api_key,
3832 })?;
3833
3834 Ok(())
3835}
3836
3837/// Start receiving chat updates for a channel
3838async fn join_channel_chat(
3839 request: proto::JoinChannelChat,
3840 response: Response<proto::JoinChannelChat>,
3841 session: Session,
3842) -> Result<()> {
3843 let channel_id = ChannelId::from_proto(request.channel_id);
3844
3845 let db = session.db().await;
3846 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3847 .await?;
3848 let messages = db
3849 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3850 .await?;
3851 response.send(proto::JoinChannelChatResponse {
3852 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3853 messages,
3854 })?;
3855 Ok(())
3856}
3857
3858/// Stop receiving chat updates for a channel
3859async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3860 let channel_id = ChannelId::from_proto(request.channel_id);
3861 session
3862 .db()
3863 .await
3864 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3865 .await?;
3866 Ok(())
3867}
3868
3869/// Retrieve the chat history for a channel
3870async fn get_channel_messages(
3871 request: proto::GetChannelMessages,
3872 response: Response<proto::GetChannelMessages>,
3873 session: Session,
3874) -> Result<()> {
3875 let channel_id = ChannelId::from_proto(request.channel_id);
3876 let messages = session
3877 .db()
3878 .await
3879 .get_channel_messages(
3880 channel_id,
3881 session.user_id(),
3882 MESSAGE_COUNT_PER_PAGE,
3883 Some(MessageId::from_proto(request.before_message_id)),
3884 )
3885 .await?;
3886 response.send(proto::GetChannelMessagesResponse {
3887 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888 messages,
3889 })?;
3890 Ok(())
3891}
3892
3893/// Retrieve specific chat messages
3894async fn get_channel_messages_by_id(
3895 request: proto::GetChannelMessagesById,
3896 response: Response<proto::GetChannelMessagesById>,
3897 session: Session,
3898) -> Result<()> {
3899 let message_ids = request
3900 .message_ids
3901 .iter()
3902 .map(|id| MessageId::from_proto(*id))
3903 .collect::<Vec<_>>();
3904 let messages = session
3905 .db()
3906 .await
3907 .get_channel_messages_by_id(session.user_id(), &message_ids)
3908 .await?;
3909 response.send(proto::GetChannelMessagesResponse {
3910 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3911 messages,
3912 })?;
3913 Ok(())
3914}
3915
3916/// Retrieve the current users notifications
3917async fn get_notifications(
3918 request: proto::GetNotifications,
3919 response: Response<proto::GetNotifications>,
3920 session: Session,
3921) -> Result<()> {
3922 let notifications = session
3923 .db()
3924 .await
3925 .get_notifications(
3926 session.user_id(),
3927 NOTIFICATION_COUNT_PER_PAGE,
3928 request.before_id.map(db::NotificationId::from_proto),
3929 )
3930 .await?;
3931 response.send(proto::GetNotificationsResponse {
3932 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3933 notifications,
3934 })?;
3935 Ok(())
3936}
3937
3938/// Mark notifications as read
3939async fn mark_notification_as_read(
3940 request: proto::MarkNotificationRead,
3941 response: Response<proto::MarkNotificationRead>,
3942 session: Session,
3943) -> Result<()> {
3944 let database = &session.db().await;
3945 let notifications = database
3946 .mark_notification_as_read_by_id(
3947 session.user_id(),
3948 NotificationId::from_proto(request.notification_id),
3949 )
3950 .await?;
3951 send_notifications(
3952 &*session.connection_pool().await,
3953 &session.peer,
3954 notifications,
3955 );
3956 response.send(proto::Ack {})?;
3957 Ok(())
3958}
3959
3960/// Get the current users information
3961async fn get_private_user_info(
3962 _request: proto::GetPrivateUserInfo,
3963 response: Response<proto::GetPrivateUserInfo>,
3964 session: Session,
3965) -> Result<()> {
3966 let db = session.db().await;
3967
3968 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3969 let user = db
3970 .get_user_by_id(session.user_id())
3971 .await?
3972 .context("user not found")?;
3973 let flags = db.get_user_flags(session.user_id()).await?;
3974
3975 response.send(proto::GetPrivateUserInfoResponse {
3976 metrics_id,
3977 staff: user.admin,
3978 flags,
3979 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
3980 })?;
3981 Ok(())
3982}
3983
3984/// Accept the terms of service (tos) on behalf of the current user
3985async fn accept_terms_of_service(
3986 _request: proto::AcceptTermsOfService,
3987 response: Response<proto::AcceptTermsOfService>,
3988 session: Session,
3989) -> Result<()> {
3990 let db = session.db().await;
3991
3992 let accepted_tos_at = Utc::now();
3993 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3994 .await?;
3995
3996 response.send(proto::AcceptTermsOfServiceResponse {
3997 accepted_tos_at: accepted_tos_at.timestamp() as u64,
3998 })?;
3999 Ok(())
4000}
4001
/// The minimum age (30 days) an account must have in order to use the LLM
/// service.
///
/// NOTE(review): this constant is not referenced in this part of the file; in
/// particular, `get_llm_api_token` does not check it. Confirm where it is
/// enforced.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4004
/// Issue an LLM API token for the current user.
///
/// The user must exist and must have accepted the terms of service. Stripe
/// (both the client and the billing object) must be configured.
///
/// Billing records are created on demand:
/// - With no billing customer on file, a Stripe customer is found or created
///   by the user's email and passed to `find_or_create_billing_customer`.
/// - With no active billing subscription, the customer is subscribed to Zed
///   Free in Stripe, and the subscription is stored in the database.
///
/// The token is then built by `LlmTokenClaims::create` from the user, the
/// staff flag, the billing customer, billing preferences, feature flags, the
/// subscription, the client's system id, and the server config.
///
/// # Errors
///
/// Fails if the user is missing, has not accepted the terms of service,
/// Stripe is not configured, any Stripe or database call fails, or the token
/// cannot be created or sent.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Ensure a billing customer exists, creating it via Stripe on first use.
    let billing_customer =
        if let Some(billing_customer) = db.get_billing_customer_by_user_id(user.id).await? {
            billing_customer
        } else {
            let customer_id = stripe_billing
                .find_or_create_customer_by_email(user.email_address.as_deref())
                .await?;

            find_or_create_billing_customer(
                &session.app_state,
                &stripe_client,
                stripe::Expandable::Id(customer_id),
            )
            .await?
            .context("billing customer not found")?
        };

    // Ensure an active subscription exists, falling back to Zed Free.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id = billing_customer
                .stripe_customer_id
                .parse::<stripe::CustomerId>()
                .context("failed to parse Stripe customer ID from database")?;

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4093
4094fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4095 let message = match message {
4096 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4097 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4098 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4099 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4100 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4101 code: frame.code.into(),
4102 reason: frame.reason.as_str().to_owned().into(),
4103 })),
4104 // We should never receive a frame while reading the message, according
4105 // to the `tungstenite` maintainers:
4106 //
4107 // > It cannot occur when you read messages from the WebSocket, but it
4108 // > can be used when you want to send the raw frames (e.g. you want to
4109 // > send the frames to the WebSocket without composing the full message first).
4110 // >
4111 // > — https://github.com/snapview/tungstenite-rs/issues/268
4112 TungsteniteMessage::Frame(_) => {
4113 bail!("received an unexpected frame while reading the message")
4114 }
4115 };
4116
4117 Ok(message)
4118}
4119
4120fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4121 match message {
4122 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4123 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4124 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4125 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4126 AxumMessage::Close(frame) => {
4127 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4128 code: frame.code.into(),
4129 reason: frame.reason.as_ref().into(),
4130 }))
4131 }
4132 }
4133}
4134
4135fn notify_membership_updated(
4136 connection_pool: &mut ConnectionPool,
4137 result: MembershipUpdated,
4138 user_id: UserId,
4139 peer: &Peer,
4140) {
4141 for membership in &result.new_channels.channel_memberships {
4142 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4143 }
4144 for channel_id in &result.removed_channels {
4145 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4146 }
4147
4148 let user_channels_update = proto::UpdateUserChannels {
4149 channel_memberships: result
4150 .new_channels
4151 .channel_memberships
4152 .iter()
4153 .map(|cm| proto::ChannelMembership {
4154 channel_id: cm.channel_id.to_proto(),
4155 role: cm.role.into(),
4156 })
4157 .collect(),
4158 ..Default::default()
4159 };
4160
4161 let mut update = build_channels_update(result.new_channels);
4162 update.delete_channels = result
4163 .removed_channels
4164 .into_iter()
4165 .map(|id| id.to_proto())
4166 .collect();
4167 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4168
4169 for connection_id in connection_pool.user_connection_ids(user_id) {
4170 peer.send(connection_id, user_channels_update.clone())
4171 .trace_err();
4172 peer.send(connection_id, update.clone()).trace_err();
4173 }
4174}
4175
4176fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4177 proto::UpdateUserChannels {
4178 channel_memberships: channels
4179 .channel_memberships
4180 .iter()
4181 .map(|m| proto::ChannelMembership {
4182 channel_id: m.channel_id.to_proto(),
4183 role: m.role.into(),
4184 })
4185 .collect(),
4186 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4187 observed_channel_message_id: channels.observed_channel_messages.clone(),
4188 }
4189}
4190
4191fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4192 let mut update = proto::UpdateChannels::default();
4193
4194 for channel in channels.channels {
4195 update.channels.push(channel.to_proto());
4196 }
4197
4198 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4199 update.latest_channel_message_ids = channels.latest_channel_messages;
4200
4201 for (channel_id, participants) in channels.channel_participants {
4202 update
4203 .channel_participants
4204 .push(proto::ChannelParticipants {
4205 channel_id: channel_id.to_proto(),
4206 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4207 });
4208 }
4209
4210 for channel in channels.invited_channels {
4211 update.channel_invitations.push(channel.to_proto());
4212 }
4213
4214 update
4215}
4216
4217fn build_initial_contacts_update(
4218 contacts: Vec<db::Contact>,
4219 pool: &ConnectionPool,
4220) -> proto::UpdateContacts {
4221 let mut update = proto::UpdateContacts::default();
4222
4223 for contact in contacts {
4224 match contact {
4225 db::Contact::Accepted { user_id, busy } => {
4226 update.contacts.push(contact_for_user(user_id, busy, pool));
4227 }
4228 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4229 db::Contact::Incoming { user_id } => {
4230 update
4231 .incoming_requests
4232 .push(proto::IncomingContactRequest {
4233 requester_id: user_id.to_proto(),
4234 })
4235 }
4236 }
4237 }
4238
4239 update
4240}
4241
4242fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4243 proto::Contact {
4244 user_id: user_id.to_proto(),
4245 online: pool.is_user_online(user_id),
4246 busy,
4247 }
4248}
4249
4250fn room_updated(room: &proto::Room, peer: &Peer) {
4251 broadcast(
4252 None,
4253 room.participants
4254 .iter()
4255 .filter_map(|participant| Some(participant.peer_id?.into())),
4256 |peer_id| {
4257 peer.send(
4258 peer_id,
4259 proto::RoomUpdated {
4260 room: Some(room.clone()),
4261 },
4262 )
4263 },
4264 );
4265}
4266
4267fn channel_updated(
4268 channel: &db::channel::Model,
4269 room: &proto::Room,
4270 peer: &Peer,
4271 pool: &ConnectionPool,
4272) {
4273 let participants = room
4274 .participants
4275 .iter()
4276 .map(|p| p.user_id)
4277 .collect::<Vec<_>>();
4278
4279 broadcast(
4280 None,
4281 pool.channel_connection_ids(channel.root_id())
4282 .filter_map(|(channel_id, role)| {
4283 role.can_see_channel(channel.visibility)
4284 .then_some(channel_id)
4285 }),
4286 |peer_id| {
4287 peer.send(
4288 peer_id,
4289 proto::UpdateChannels {
4290 channel_participants: vec![proto::ChannelParticipants {
4291 channel_id: channel.id.to_proto(),
4292 participant_user_ids: participants.clone(),
4293 }],
4294 ..Default::default()
4295 },
4296 )
4297 },
4298 );
4299}
4300
4301async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4302 let db = session.db().await;
4303
4304 let contacts = db.get_contacts(user_id).await?;
4305 let busy = db.is_user_busy(user_id).await?;
4306
4307 let pool = session.connection_pool().await;
4308 let updated_contact = contact_for_user(user_id, busy, &pool);
4309 for contact in contacts {
4310 if let db::Contact::Accepted {
4311 user_id: contact_user_id,
4312 ..
4313 } = contact
4314 {
4315 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4316 session
4317 .peer
4318 .send(
4319 contact_conn_id,
4320 proto::UpdateContacts {
4321 contacts: vec![updated_contact.clone()],
4322 remove_contacts: Default::default(),
4323 incoming_requests: Default::default(),
4324 remove_incoming_requests: Default::default(),
4325 outgoing_requests: Default::default(),
4326 remove_outgoing_requests: Default::default(),
4327 },
4328 )
4329 .trace_err();
4330 }
4331 }
4332 }
4333 Ok(())
4334}
4335
/// Removes `connection_id` from the room it is currently in, if any, and fans
/// out the consequences of leaving.
///
/// When `leave_room` reports that a room was left, this:
/// - notifies the connections of every project that was left (`project_left`);
/// - broadcasts the updated room to its participants (`room_updated`);
/// - if the room belongs to a channel, broadcasts the new participant list to
///   the channel's members (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes the contact entries of the leaving user and of those users;
/// - removes the user from the LiveKit room, deleting that room when
///   `left_room.deleted` is set.
///
/// If the connection was not in a room, it returns `Ok(())` immediately.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// NOTE(review): because of the `?` on `update_user_contacts`, a failure there
/// returns early and skips the LiveKit cleanup at the end — confirm this is
/// intended.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast. A set, so each user is
    // updated once; iteration order is unspecified.
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized: each is assigned inside the `if let` below and used
    // after it (the `else` branch returns).
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the `Room` passed to
        // `room_updated` carries a default (empty) `livekit_room`.
        // NOTE(review): presumably intentional — confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The pool guard is a temporary and is released at the end of this call.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released at the end of this block, before
    // `update_user_contacts` (which takes the pool itself) is awaited.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4409
4410async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4411 let left_channel_buffers = session
4412 .db()
4413 .await
4414 .leave_channel_buffers(session.connection_id)
4415 .await?;
4416
4417 for left_buffer in left_channel_buffers {
4418 channel_buffer_updated(
4419 session.connection_id,
4420 left_buffer.connections,
4421 &proto::UpdateChannelBufferCollaborators {
4422 channel_id: left_buffer.channel_id.to_proto(),
4423 collaborators: left_buffer.collaborators,
4424 },
4425 &session.peer,
4426 );
4427 }
4428
4429 Ok(())
4430}
4431
4432fn project_left(project: &db::LeftProject, session: &Session) {
4433 for connection_id in &project.connection_ids {
4434 if project.should_unshare {
4435 session
4436 .peer
4437 .send(
4438 *connection_id,
4439 proto::UnshareProject {
4440 project_id: project.id.to_proto(),
4441 },
4442 )
4443 .trace_err();
4444 } else {
4445 session
4446 .peer
4447 .send(
4448 *connection_id,
4449 proto::RemoveProjectCollaborator {
4450 project_id: project.id.to_proto(),
4451 peer_id: Some(session.connection_id.into()),
4452 },
4453 )
4454 .trace_err();
4455 }
4456 }
4457}
4458
/// Extension for fallible values whose failure should be logged and then
/// discarded rather than propagated.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Returns the success value as `Some`, or reports the failure and yields
    /// `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4464
4465impl<T, E> ResultExt for Result<T, E>
4466where
4467 E: std::fmt::Debug,
4468{
4469 type Ok = T;
4470
4471 #[track_caller]
4472 fn trace_err(self) -> Option<T> {
4473 match self {
4474 Ok(value) => Some(value),
4475 Err(error) => {
4476 tracing::error!("{:?}", error);
4477 None
4478 }
4479 }
4480 }
4481}