1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::stripe_client::StripeCustomerId;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{
12 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
13 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
14 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
15 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
16 },
17 executor::Executor,
18};
19use anyhow::{Context as _, anyhow, bail};
20use async_tungstenite::tungstenite::{
21 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
22};
23use axum::{
24 Extension, Router, TypedHeader,
25 body::Body,
26 extract::{
27 ConnectInfo, WebSocketUpgrade,
28 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
29 },
30 headers::{Header, HeaderName},
31 http::StatusCode,
32 middleware,
33 response::IntoResponse,
34 routing::get,
35};
36use chrono::Utc;
37use collections::{HashMap, HashSet};
38pub use connection_pool::{ConnectionPool, ZedVersion};
39use core::fmt::{self, Debug, Formatter};
40use reqwest_client::ReqwestClient;
41use rpc::proto::split_repository_update;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
46 stream::FuturesUnordered,
47};
48use prometheus::{IntGauge, register_int_gauge};
49use rpc::{
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 Arc, OnceLock,
68 atomic::{AtomicBool, Ordering::SeqCst},
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{Semaphore, watch};
74use tower::ServiceBuilder;
75use tracing::{
76 Instrument,
77 field::{self},
78 info_span, instrument,
79};
80
/// How long the server waits for a disconnected client to come back before treating it as gone.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay before the startup cleanup task runs.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone, the
/// resources they left behind can be cleaned up safely.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111}
112
113impl Principal {
114 fn user(&self) -> &User {
115 match self {
116 Principal::User(user) => user,
117 Principal::Impersonated { user, .. } => user,
118 }
119 }
120
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 }
133 }
134}
135
136#[derive(Clone)]
137struct Session {
138 principal: Principal,
139 connection_id: ConnectionId,
140 db: Arc<tokio::sync::Mutex<DbHandle>>,
141 peer: Arc<Peer>,
142 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
143 app_state: Arc<AppState>,
144 supermaven_client: Option<Arc<SupermavenAdminApi>>,
145 /// The GeoIP country code for the user.
146 #[allow(unused)]
147 geoip_country_code: Option<String>,
148 system_id: Option<String>,
149 _executor: Executor,
150}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn is_staff(&self) -> bool {
173 match &self.principal {
174 Principal::User(user) => user.admin,
175 Principal::Impersonated { .. } => true,
176 }
177 }
178
179 fn user_id(&self) -> UserId {
180 match &self.principal {
181 Principal::User(user) => user.id,
182 Principal::Impersonated { user, .. } => user.id,
183 }
184 }
185
186 pub fn email(&self) -> Option<String> {
187 match &self.principal {
188 Principal::User(user) => user.email_address.clone(),
189 Principal::Impersonated { user, .. } => user.email_address.clone(),
190 }
191 }
192}
193
194impl Debug for Session {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let mut result = f.debug_struct("Session");
197 match &self.principal {
198 Principal::User(user) => {
199 result.field("user", &user.github_login);
200 }
201 Principal::Impersonated { user, admin } => {
202 result.field("user", &user.github_login);
203 result.field("impersonator", &admin.github_login);
204 }
205 }
206 result.field("connection_id", &self.connection_id).finish()
207 }
208}
209
210struct DbHandle(Arc<Database>);
211
212impl Deref for DbHandle {
213 type Target = Database;
214
215 fn deref(&self) -> &Self::Target {
216 self.0.as_ref()
217 }
218}
219
/// The collaboration RPC server.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it to stop.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` makes this guard `!Send`, so it cannot be held across an `.await`
    // in a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the peer and the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
240
241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
242where
243 S: Serializer,
244 T: Deref<Target = U>,
245 U: Serialize,
246{
247 Serialize::serialize(value.deref(), serializer)
248}
249
250impl Server {
251 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
252 let mut server = Self {
253 id: parking_lot::Mutex::new(id),
254 peer: Peer::new(id.0 as u32),
255 app_state: app_state.clone(),
256 connection_pool: Default::default(),
257 handlers: Default::default(),
258 teardown: watch::channel(false).0,
259 };
260
261 server
262 .add_request_handler(ping)
263 .add_request_handler(create_room)
264 .add_request_handler(join_room)
265 .add_request_handler(rejoin_room)
266 .add_request_handler(leave_room)
267 .add_request_handler(set_room_participant_role)
268 .add_request_handler(call)
269 .add_request_handler(cancel_call)
270 .add_message_handler(decline_call)
271 .add_request_handler(update_participant_location)
272 .add_request_handler(share_project)
273 .add_message_handler(unshare_project)
274 .add_request_handler(join_project)
275 .add_message_handler(leave_project)
276 .add_request_handler(update_project)
277 .add_request_handler(update_worktree)
278 .add_request_handler(update_repository)
279 .add_request_handler(remove_repository)
280 .add_message_handler(start_language_server)
281 .add_message_handler(update_language_server)
282 .add_message_handler(update_diagnostic_summary)
283 .add_message_handler(update_worktree_settings)
284 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
287 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
288 .add_request_handler(forward_find_search_candidates_request)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
294 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
295 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
296 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
297 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
298 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
299 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
304 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
305 .add_request_handler(
306 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
307 )
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
312 .add_request_handler(
313 forward_read_only_project_request::<proto::LanguageServerIdForName>,
314 )
315 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
316 .add_request_handler(
317 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
318 )
319 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
320 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
321 .add_request_handler(
322 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
323 )
324 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
325 .add_request_handler(
326 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
327 )
328 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
329 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
330 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
331 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
332 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
333 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
334 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
335 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
339 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
340 .add_request_handler(
341 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
342 )
343 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
344 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
345 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
346 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
347 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
348 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
349 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
350 .add_message_handler(create_buffer_for_peer)
351 .add_request_handler(update_buffer)
352 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
357 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
358 .add_message_handler(
359 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
360 )
361 .add_request_handler(get_users)
362 .add_request_handler(fuzzy_search_users)
363 .add_request_handler(request_contact)
364 .add_request_handler(remove_contact)
365 .add_request_handler(respond_to_contact_request)
366 .add_message_handler(subscribe_to_channels)
367 .add_request_handler(create_channel)
368 .add_request_handler(delete_channel)
369 .add_request_handler(invite_channel_member)
370 .add_request_handler(remove_channel_member)
371 .add_request_handler(set_channel_member_role)
372 .add_request_handler(set_channel_visibility)
373 .add_request_handler(rename_channel)
374 .add_request_handler(join_channel_buffer)
375 .add_request_handler(leave_channel_buffer)
376 .add_message_handler(update_channel_buffer)
377 .add_request_handler(rejoin_channel_buffers)
378 .add_request_handler(get_channel_members)
379 .add_request_handler(respond_to_channel_invite)
380 .add_request_handler(join_channel)
381 .add_request_handler(join_channel_chat)
382 .add_message_handler(leave_channel_chat)
383 .add_request_handler(send_channel_message)
384 .add_request_handler(remove_channel_message)
385 .add_request_handler(update_channel_message)
386 .add_request_handler(get_channel_messages)
387 .add_request_handler(get_channel_messages_by_id)
388 .add_request_handler(get_notifications)
389 .add_request_handler(mark_notification_as_read)
390 .add_request_handler(move_channel)
391 .add_request_handler(reorder_channel)
392 .add_request_handler(follow)
393 .add_message_handler(unfollow)
394 .add_message_handler(update_followers)
395 .add_request_handler(get_private_user_info)
396 .add_request_handler(get_llm_api_token)
397 .add_request_handler(accept_terms_of_service)
398 .add_message_handler(acknowledge_channel_message)
399 .add_message_handler(acknowledge_buffer_version)
400 .add_request_handler(get_supermaven_api_key)
401 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
402 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
403 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
404 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
405 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
406 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
407 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
408 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
409 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
410 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
411 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
412 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
413 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
414 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
415 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
416 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
417 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
418 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
419 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
420 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
421 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
422 .add_message_handler(update_context);
423
424 Arc::new(server)
425 }
426
    /// Spawns a detached cleanup task and returns `Ok(())` immediately.
    ///
    /// The task waits for `CLEANUP_TIMEOUT` and then cleans up resources that the database
    /// reports as stale relative to this `server_id` (presumably left behind by earlier server
    /// instances). It:
    /// - deletes stale channel-chat participants;
    /// - clears stale channel-buffer collaborators and sends `UpdateChannelBufferCollaborators`
    ///   to the remaining connections;
    /// - clears stale room participants, broadcasts room/channel updates, sends `CallCanceled`
    ///   for calls that were canceled, refreshes contacts of affected users, and deletes the
    ///   LiveKit room when no participants remain;
    /// - deletes stale channel-chat participants again, clears old worktree entries, and
    ///   deletes stale servers.
    ///
    /// Each database or send failure is passed to `trace_err()` and the task moves on to the
    /// next step.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the room is refreshed, then acted on below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is deleted only once nobody is left.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed busy status to their
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made right after the timeout; it may be
                // intended to catch participants that went stale during the cleanup above — confirm.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
597
598 pub fn teardown(&self) {
599 self.peer.teardown();
600 self.connection_pool.lock().reset();
601 let _ = self.teardown.send(true);
602 }
603
604 #[cfg(test)]
605 pub fn reset(&self, id: ServerId) {
606 self.teardown();
607 *self.id.lock() = id;
608 self.peer.reset(id.0 as u32);
609 let _ = self.teardown.send(false);
610 }
611
612 #[cfg(test)]
613 pub fn id(&self) -> ServerId {
614 *self.id.lock()
615 }
616
    /// Registers a type-erased handler for messages of type `M`.
    ///
    /// The stored closure downcasts the envelope to `TypedEnvelope<M>`, calls `handler`, and
    /// logs the outcome with timings:
    /// - `total_duration_ms` is measured from the envelope's `received_at`;
    /// - `processing_duration_ms` is measured from just before `handler` is invoked;
    /// - `queue_duration_ms` is the difference between the two.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // The map is keyed by `TypeId::of::<M>()` and dispatch looks handlers up by
                // `payload_type_id()`, so this downcast is expected to always succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
665
666 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
667 where
668 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
669 Fut: 'static + Send + Future<Output = Result<()>>,
670 M: EnvelopedMessage,
671 {
672 self.add_handler(move |envelope, session| handler(envelope.payload, session));
673 self
674 }
675
676 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
677 where
678 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
679 Fut: Send + Future<Output = Result<()>>,
680 M: RequestMessage,
681 {
682 let handler = Arc::new(handler);
683 self.add_handler(move |envelope, session| {
684 let receipt = envelope.receipt();
685 let handler = handler.clone();
686 async move {
687 let peer = session.peer.clone();
688 let responded = Arc::new(AtomicBool::default());
689 let response = Response {
690 peer: peer.clone(),
691 responded: responded.clone(),
692 receipt,
693 };
694 match (handler)(envelope.payload, response, session).await {
695 Ok(()) => {
696 if responded.load(std::sync::atomic::Ordering::SeqCst) {
697 Ok(())
698 } else {
699 Err(anyhow!("handler did not send a response"))?
700 }
701 }
702 Err(error) => {
703 let proto_err = match &error {
704 Error::Internal(err) => err.to_proto(),
705 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
706 };
707 peer.respond_with_error(receipt, proto_err)?;
708 Err(error)
709 }
710 }
711 }
712 })
713 }
714
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// The returned future, instrumented with a "handle connection" span:
    /// 1. returns immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer and records `connection_id` on the span;
    /// 3. builds the connection's `Session` (with a Supermaven admin client only when
    ///    `supermaven_admin_api_key` is configured) and sends the initial client update;
    /// 4. dispatches incoming messages to registered handlers until the connection closes,
    ///    the I/O future completes, or teardown is signalled;
    /// 5. after the loop exits because the connection closed or I/O completed, calls
    ///    `connection_lost` to sign the user out.
    ///
    /// If teardown is signalled while the loop runs, the future returns without calling
    /// `connection_lost`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(connection_id, zed_version, send_connection_id, &session)
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection are in flight: a permit is taken before the
            // next message is read and released when that message's handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in order: teardown, then I/O, then finished
                // foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
862
863 async fn send_initial_client_update(
864 &self,
865 connection_id: ConnectionId,
866 zed_version: ZedVersion,
867 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
868 session: &Session,
869 ) -> Result<()> {
870 self.peer.send(
871 connection_id,
872 proto::Hello {
873 peer_id: Some(connection_id.into()),
874 },
875 )?;
876 tracing::info!("sent hello message");
877 if let Some(send_connection_id) = send_connection_id.take() {
878 let _ = send_connection_id.send(connection_id);
879 }
880
881 match &session.principal {
882 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
883 if !user.connected_once {
884 self.peer.send(connection_id, proto::ShowContacts {})?;
885 self.app_state
886 .db
887 .set_user_connected_once(user.id, true)
888 .await?;
889 }
890
891 update_user_plan(session).await?;
892
893 let contacts = self.app_state.db.get_contacts(user.id).await?;
894
895 {
896 let mut pool = self.connection_pool.lock();
897 pool.add_connection(connection_id, user.id, user.admin, zed_version);
898 self.peer.send(
899 connection_id,
900 build_initial_contacts_update(contacts, &pool),
901 )?;
902 }
903
904 if should_auto_subscribe_to_channels(zed_version) {
905 subscribe_user_to_channels(user.id, session).await?;
906 }
907
908 if let Some(incoming_call) =
909 self.app_state.db.incoming_call_for_user(user.id).await?
910 {
911 self.peer.send(connection_id, incoming_call)?;
912 }
913
914 update_user_contacts(user.id, session).await?;
915 }
916 }
917
918 Ok(())
919 }
920
921 pub async fn invite_code_redeemed(
922 self: &Arc<Self>,
923 inviter_id: UserId,
924 invitee_id: UserId,
925 ) -> Result<()> {
926 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
927 if let Some(code) = &user.invite_code {
928 let pool = self.connection_pool.lock();
929 let invitee_contact = contact_for_user(invitee_id, false, &pool);
930 for connection_id in pool.user_connection_ids(inviter_id) {
931 self.peer.send(
932 connection_id,
933 proto::UpdateContacts {
934 contacts: vec![invitee_contact.clone()],
935 ..Default::default()
936 },
937 )?;
938 self.peer.send(
939 connection_id,
940 proto::UpdateInviteInfo {
941 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
942 count: user.invite_count as u32,
943 },
944 )?;
945 }
946 }
947 }
948 Ok(())
949 }
950
951 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
952 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
953 if let Some(invite_code) = &user.invite_code {
954 let pool = self.connection_pool.lock();
955 for connection_id in pool.user_connection_ids(user_id) {
956 self.peer.send(
957 connection_id,
958 proto::UpdateInviteInfo {
959 url: format!(
960 "{}{}",
961 self.app_state.config.invite_link_prefix, invite_code
962 ),
963 count: user.invite_count as u32,
964 },
965 )?;
966 }
967 }
968 }
969 Ok(())
970 }
971
972 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
973 let user = self
974 .app_state
975 .db
976 .get_user_by_id(user_id)
977 .await?
978 .context("user not found")?;
979
980 let update_user_plan = make_update_user_plan_message(
981 &user,
982 user.admin,
983 &self.app_state.db,
984 self.app_state.llm_db.clone(),
985 )
986 .await?;
987
988 let pool = self.connection_pool.lock();
989 for connection_id in pool.user_connection_ids(user_id) {
990 self.peer
991 .send(connection_id, update_user_plan.clone())
992 .trace_err();
993 }
994
995 Ok(())
996 }
997
998 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
999 let pool = self.connection_pool.lock();
1000 for connection_id in pool.user_connection_ids(user_id) {
1001 self.peer
1002 .send(connection_id, proto::RefreshLlmToken {})
1003 .trace_err();
1004 }
1005 }
1006
1007 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1008 ServerSnapshot {
1009 connection_pool: ConnectionPoolGuard {
1010 guard: self.connection_pool.lock(),
1011 _not_send: PhantomData,
1012 },
1013 peer: &self.peer,
1014 }
1015 }
1016}
1017
1018impl Deref for ConnectionPoolGuard<'_> {
1019 type Target = ConnectionPool;
1020
1021 fn deref(&self) -> &Self::Target {
1022 &self.guard
1023 }
1024}
1025
1026impl DerefMut for ConnectionPoolGuard<'_> {
1027 fn deref_mut(&mut self) -> &mut Self::Target {
1028 &mut self.guard
1029 }
1030}
1031
1032impl Drop for ConnectionPoolGuard<'_> {
1033 fn drop(&mut self) {
1034 #[cfg(test)]
1035 self.check_invariants();
1036 }
1037}
1038
1039fn broadcast<F>(
1040 sender_id: Option<ConnectionId>,
1041 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1042 mut f: F,
1043) where
1044 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1045{
1046 for receiver_id in receiver_ids {
1047 if Some(receiver_id) != sender_id {
1048 if let Err(error) = f(receiver_id) {
1049 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1050 }
1051 }
1052 }
1053}
1054
1055pub struct ProtocolVersion(u32);
1056
1057impl Header for ProtocolVersion {
1058 fn name() -> &'static HeaderName {
1059 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1060 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1061 }
1062
1063 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064 where
1065 Self: Sized,
1066 I: Iterator<Item = &'i axum::http::HeaderValue>,
1067 {
1068 let version = values
1069 .next()
1070 .ok_or_else(axum::headers::Error::invalid)?
1071 .to_str()
1072 .map_err(|_| axum::headers::Error::invalid())?
1073 .parse()
1074 .map_err(|_| axum::headers::Error::invalid())?;
1075 Ok(Self(version))
1076 }
1077
1078 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079 values.extend([self.0.to_string().parse().unwrap()]);
1080 }
1081}
1082
1083pub struct AppVersionHeader(SemanticVersion);
1084impl Header for AppVersionHeader {
1085 fn name() -> &'static HeaderName {
1086 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1087 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1088 }
1089
1090 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1091 where
1092 Self: Sized,
1093 I: Iterator<Item = &'i axum::http::HeaderValue>,
1094 {
1095 let version = values
1096 .next()
1097 .ok_or_else(axum::headers::Error::invalid)?
1098 .to_str()
1099 .map_err(|_| axum::headers::Error::invalid())?
1100 .parse()
1101 .map_err(|_| axum::headers::Error::invalid())?;
1102 Ok(Self(version))
1103 }
1104
1105 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1106 values.extend([self.0.to_string().parse().unwrap()]);
1107 }
1108}
1109
/// Builds the collab RPC router.
///
/// `/rpc` upgrades to a websocket via [`handle_websocket_request`]. It is wrapped in a layer
/// stack that provides the `AppState` extension and runs `auth::validate_header`.
///
/// Axum's `Router::layer` only wraps routes registered before the call, so `/metrics` (added
/// afterwards) is outside that auth stack. The final `Extension(server)` layer covers both
/// routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc`, which is the only route registered so far.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1121
1122pub async fn handle_websocket_request(
1123 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1124 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1125 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1126 Extension(server): Extension<Arc<Server>>,
1127 Extension(principal): Extension<Principal>,
1128 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1129 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1130 ws: WebSocketUpgrade,
1131) -> axum::response::Response {
1132 if protocol_version != rpc::PROTOCOL_VERSION {
1133 return (
1134 StatusCode::UPGRADE_REQUIRED,
1135 "client must be upgraded".to_string(),
1136 )
1137 .into_response();
1138 }
1139
1140 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1141 return (
1142 StatusCode::UPGRADE_REQUIRED,
1143 "no version header found".to_string(),
1144 )
1145 .into_response();
1146 };
1147
1148 if !version.can_collaborate() {
1149 return (
1150 StatusCode::UPGRADE_REQUIRED,
1151 "client must be upgraded".to_string(),
1152 )
1153 .into_response();
1154 }
1155
1156 let socket_address = socket_address.to_string();
1157 ws.on_upgrade(move |socket| {
1158 let socket = socket
1159 .map_ok(to_tungstenite_message)
1160 .err_into()
1161 .with(|message| async move { to_axum_message(message) });
1162 let connection = Connection::new(Box::pin(socket));
1163 async move {
1164 server
1165 .handle_connection(
1166 connection,
1167 socket_address,
1168 principal,
1169 version,
1170 country_code_header.map(|header| header.to_string()),
1171 system_id_header.map(|header| header.to_string()),
1172 None,
1173 Executor::Production,
1174 )
1175 .await;
1176 }
1177 })
1178}
1179
1180pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1181 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1182 let connections_metric = CONNECTIONS_METRIC
1183 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1184
1185 let connections = server
1186 .connection_pool
1187 .lock()
1188 .connections()
1189 .filter(|connection| !connection.admin)
1190 .count();
1191 connections_metric.set(connections as _);
1192
1193 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1194 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1195 register_int_gauge!(
1196 "shared_projects",
1197 "number of open projects with one or more guests"
1198 )
1199 .unwrap()
1200 });
1201
1202 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1203 shared_projects_metric.set(shared_projects as _);
1204
1205 let encoder = prometheus::TextEncoder::new();
1206 let metric_families = prometheus::gather();
1207 let encoded_metrics = encoder
1208 .encode_to_string(&metric_families)
1209 .map_err(|err| anyhow!("{err}"))?;
1210 Ok(encoded_metrics)
1211}
1212
/// Tears down a session after its connection has closed.
///
/// Immediately, it does the following:
/// - disconnects the peer;
/// - removes the connection from the pool (a failure here aborts with an error);
/// - records the lost connection in the database (errors are only logged).
///
/// It then races a `RECONNECT_TIMEOUT` sleep against a change on `teardown`. If `teardown`
/// fires first, nothing else happens; presumably this is a shutdown signal (TODO confirm with
/// the caller). If the timeout elapses first, it does the following:
/// - removes the session from its room and channel buffers (failures logged);
/// - declines a pending call when the user has no other connection online;
/// - notifies the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure is logged rather than aborting the teardown.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls arms in declaration order, so the timeout arm is checked first
    // whenever both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                // `None` room id: presumably declines whichever call is pending — TODO confirm.
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1259
1260/// Acknowledges a ping from a client, used to keep the connection alive.
1261async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1262 response.send(proto::Ack {})?;
1263 Ok(())
1264}
1265
1266/// Creates a new room for calling (outside of channels)
1267async fn create_room(
1268 _request: proto::CreateRoom,
1269 response: Response<proto::CreateRoom>,
1270 session: Session,
1271) -> Result<()> {
1272 let livekit_room = nanoid::nanoid!(30);
1273
1274 let live_kit_connection_info = util::maybe!(async {
1275 let live_kit = session.app_state.livekit_client.as_ref();
1276 let live_kit = live_kit?;
1277 let user_id = session.user_id().to_string();
1278
1279 let token = live_kit
1280 .room_token(&livekit_room, &user_id.to_string())
1281 .trace_err()?;
1282
1283 Some(proto::LiveKitConnectionInfo {
1284 server_url: live_kit.url().into(),
1285 token,
1286 can_publish: true,
1287 })
1288 })
1289 .await;
1290
1291 let room = session
1292 .db()
1293 .await
1294 .create_room(session.user_id(), session.connection_id, &livekit_room)
1295 .await?;
1296
1297 response.send(proto::CreateRoomResponse {
1298 room: Some(room.clone()),
1299 live_kit_connection_info,
1300 })?;
1301
1302 update_user_contacts(session.user_id(), &session).await?;
1303 Ok(())
1304}
1305
1306/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1307async fn join_room(
1308 request: proto::JoinRoom,
1309 response: Response<proto::JoinRoom>,
1310 session: Session,
1311) -> Result<()> {
1312 let room_id = RoomId::from_proto(request.id);
1313
1314 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1315
1316 if let Some(channel_id) = channel_id {
1317 return join_channel_internal(channel_id, Box::new(response), session).await;
1318 }
1319
1320 let joined_room = {
1321 let room = session
1322 .db()
1323 .await
1324 .join_room(room_id, session.user_id(), session.connection_id)
1325 .await?;
1326 room_updated(&room.room, &session.peer);
1327 room.into_inner()
1328 };
1329
1330 for connection_id in session
1331 .connection_pool()
1332 .await
1333 .user_connection_ids(session.user_id())
1334 {
1335 session
1336 .peer
1337 .send(
1338 connection_id,
1339 proto::CallCanceled {
1340 room_id: room_id.to_proto(),
1341 },
1342 )
1343 .trace_err();
1344 }
1345
1346 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1347 {
1348 live_kit
1349 .room_token(
1350 &joined_room.room.livekit_room,
1351 &session.user_id().to_string(),
1352 )
1353 .trace_err()
1354 .map(|token| proto::LiveKitConnectionInfo {
1355 server_url: live_kit.url().into(),
1356 token,
1357 can_publish: true,
1358 })
1359 } else {
1360 None
1361 };
1362
1363 response.send(proto::JoinRoomResponse {
1364 room: Some(joined_room.room),
1365 channel_id: None,
1366 live_kit_connection_info,
1367 })?;
1368
1369 update_user_contacts(session.user_id(), &session).await?;
1370 Ok(())
1371}
1372
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin yields the room, "reshared" projects and "rejoined" projects. Judging by
/// the messages sent, these are presumably projects this connection hosted and projects it had
/// joined as a guest (TODO confirm). This handler does the following, in order:
/// 1. replies with the room and with the collaborators/state of those projects;
/// 2. refreshes the room for its participants;
/// 3. for each reshared project, tells collaborators that this session's peer id changed and
///    forwards the project's worktree list to them, excluding this connection;
/// 4. streams rejoined-project state through `notify_rejoined_projects`;
/// 5. afterwards refreshes the room's channel (if any) and the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Only `room` and `channel` escape the scope below; the rejoin guard is consumed inside it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator to swap the old peer id for this new connection (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list on behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1464
1465fn notify_rejoined_projects(
1466 rejoined_projects: &mut Vec<RejoinedProject>,
1467 session: &Session,
1468) -> Result<()> {
1469 for project in rejoined_projects.iter() {
1470 for collaborator in &project.collaborators {
1471 session
1472 .peer
1473 .send(
1474 collaborator.connection_id,
1475 proto::UpdateProjectCollaborator {
1476 project_id: project.id.to_proto(),
1477 old_peer_id: Some(project.old_connection_id.into()),
1478 new_peer_id: Some(session.connection_id.into()),
1479 },
1480 )
1481 .trace_err();
1482 }
1483 }
1484
1485 for project in rejoined_projects {
1486 for worktree in mem::take(&mut project.worktrees) {
1487 // Stream this worktree's entries.
1488 let message = proto::UpdateWorktree {
1489 project_id: project.id.to_proto(),
1490 worktree_id: worktree.id,
1491 abs_path: worktree.abs_path.clone(),
1492 root_name: worktree.root_name,
1493 updated_entries: worktree.updated_entries,
1494 removed_entries: worktree.removed_entries,
1495 scan_id: worktree.scan_id,
1496 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1497 updated_repositories: worktree.updated_repositories,
1498 removed_repositories: worktree.removed_repositories,
1499 };
1500 for update in proto::split_worktree_update(message) {
1501 session.peer.send(session.connection_id, update)?;
1502 }
1503
1504 // Stream this worktree's diagnostics.
1505 for summary in worktree.diagnostic_summaries {
1506 session.peer.send(
1507 session.connection_id,
1508 proto::UpdateDiagnosticSummary {
1509 project_id: project.id.to_proto(),
1510 worktree_id: worktree.id,
1511 summary: Some(summary),
1512 },
1513 )?;
1514 }
1515
1516 for settings_file in worktree.settings_files {
1517 session.peer.send(
1518 session.connection_id,
1519 proto::UpdateWorktreeSettings {
1520 project_id: project.id.to_proto(),
1521 worktree_id: worktree.id,
1522 path: settings_file.path,
1523 content: Some(settings_file.content),
1524 kind: Some(settings_file.kind.to_proto().into()),
1525 },
1526 )?;
1527 }
1528 }
1529
1530 for repository in mem::take(&mut project.updated_repositories) {
1531 for update in split_repository_update(repository) {
1532 session.peer.send(session.connection_id, update)?;
1533 }
1534 }
1535
1536 for id in mem::take(&mut project.removed_repositories) {
1537 session.peer.send(
1538 session.connection_id,
1539 proto::RemoveRepository {
1540 project_id: project.id.to_proto(),
1541 id,
1542 },
1543 )?;
1544 }
1545 }
1546
1547 Ok(())
1548}
1549
1550/// leave room disconnects from the room.
1551async fn leave_room(
1552 _: proto::LeaveRoom,
1553 response: Response<proto::LeaveRoom>,
1554 session: Session,
1555) -> Result<()> {
1556 leave_room_for_session(&session, session.connection_id).await?;
1557 response.send(proto::Ack {})?;
1558 Ok(())
1559}
1560
1561/// Updates the permissions of someone else in the room.
1562async fn set_room_participant_role(
1563 request: proto::SetRoomParticipantRole,
1564 response: Response<proto::SetRoomParticipantRole>,
1565 session: Session,
1566) -> Result<()> {
1567 let user_id = UserId::from_proto(request.user_id);
1568 let role = ChannelRole::from(request.role());
1569
1570 let (livekit_room, can_publish) = {
1571 let room = session
1572 .db()
1573 .await
1574 .set_room_participant_role(
1575 session.user_id(),
1576 RoomId::from_proto(request.room_id),
1577 user_id,
1578 role,
1579 )
1580 .await?;
1581
1582 let livekit_room = room.livekit_room.clone();
1583 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1584 room_updated(&room, &session.peer);
1585 (livekit_room, can_publish)
1586 };
1587
1588 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1589 live_kit
1590 .update_participant(
1591 livekit_room.clone(),
1592 request.user_id.to_string(),
1593 livekit_api::proto::ParticipantPermission {
1594 can_subscribe: true,
1595 can_publish,
1596 can_publish_data: can_publish,
1597 hidden: false,
1598 recorder: false,
1599 },
1600 )
1601 .await
1602 .trace_err();
1603 }
1604
1605 response.send(proto::Ack {})?;
1606 Ok(())
1607}
1608
/// Call someone else into the current room
///
/// Only contacts of the caller may be called. The call is recorded in the database, the room is
/// refreshed, and the called user's contacts are updated. The incoming-call request is then
/// sent concurrently to every connection of the called user, and the first one that answers
/// successfully completes the request with an `Ack`. If none succeed (including when the user
/// has no connections), the call is marked as failed, the room and contacts are refreshed
/// again, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the call payload out so the database guard can be dropped at the end of the block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful answer wins; failed answers are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: record the failure and report it to the caller.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1677
1678/// Cancel an outgoing call.
1679async fn cancel_call(
1680 request: proto::CancelCall,
1681 response: Response<proto::CancelCall>,
1682 session: Session,
1683) -> Result<()> {
1684 let called_user_id = UserId::from_proto(request.called_user_id);
1685 let room_id = RoomId::from_proto(request.room_id);
1686 {
1687 let room = session
1688 .db()
1689 .await
1690 .cancel_call(room_id, session.connection_id, called_user_id)
1691 .await?;
1692 room_updated(&room, &session.peer);
1693 }
1694
1695 for connection_id in session
1696 .connection_pool()
1697 .await
1698 .user_connection_ids(called_user_id)
1699 {
1700 session
1701 .peer
1702 .send(
1703 connection_id,
1704 proto::CallCanceled {
1705 room_id: room_id.to_proto(),
1706 },
1707 )
1708 .trace_err();
1709 }
1710 response.send(proto::Ack {})?;
1711
1712 update_user_contacts(called_user_id, &session).await?;
1713 Ok(())
1714}
1715
1716/// Decline an incoming call.
1717async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1718 let room_id = RoomId::from_proto(message.room_id);
1719 {
1720 let room = session
1721 .db()
1722 .await
1723 .decline_call(Some(room_id), session.user_id())
1724 .await?
1725 .context("declining call")?;
1726 room_updated(&room, &session.peer);
1727 }
1728
1729 for connection_id in session
1730 .connection_pool()
1731 .await
1732 .user_connection_ids(session.user_id())
1733 {
1734 session
1735 .peer
1736 .send(
1737 connection_id,
1738 proto::CallCanceled {
1739 room_id: room_id.to_proto(),
1740 },
1741 )
1742 .trace_err();
1743 }
1744 update_user_contacts(session.user_id(), &session).await?;
1745 Ok(())
1746}
1747
1748/// Updates other participants in the room with your current location.
1749async fn update_participant_location(
1750 request: proto::UpdateParticipantLocation,
1751 response: Response<proto::UpdateParticipantLocation>,
1752 session: Session,
1753) -> Result<()> {
1754 let room_id = RoomId::from_proto(request.room_id);
1755 let location = request.location.context("invalid location")?;
1756
1757 let db = session.db().await;
1758 let room = db
1759 .update_room_participant_location(room_id, session.connection_id, location)
1760 .await?;
1761
1762 room_updated(&room, &session.peer);
1763 response.send(proto::Ack {})?;
1764 Ok(())
1765}
1766
/// Share a project into the room.
///
/// Records the project with its worktrees (and whether it is an SSH project) in the database,
/// replies with the new project id, and then refreshes the room for its participants.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // `&*` borrows through the returned guard; the guard stays alive for the rest of the fn.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1793 let project_id = ProjectId::from_proto(message.project_id);
1794 unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
1797async fn unshare_project_internal(
1798 project_id: ProjectId,
1799 connection_id: ConnectionId,
1800 session: &Session,
1801) -> Result<()> {
1802 let delete = {
1803 let room_guard = session
1804 .db()
1805 .await
1806 .unshare_project(project_id, connection_id)
1807 .await?;
1808
1809 let (delete, room, guest_connection_ids) = &*room_guard;
1810
1811 let message = proto::UnshareProject {
1812 project_id: project_id.to_proto(),
1813 };
1814
1815 broadcast(
1816 Some(connection_id),
1817 guest_connection_ids.iter().copied(),
1818 |conn_id| session.peer.send(conn_id, message.clone()),
1819 );
1820 if let Some(room) = room {
1821 room_updated(room, &session.peer);
1822 }
1823
1824 *delete
1825 };
1826
1827 if delete {
1828 let db = session.db().await;
1829 db.delete_project(project_id).await?;
1830 }
1831
1832 Ok(())
1833}
1834
/// Join someone else's shared project.
///
/// Registers this connection as a participant in the database. It then explicitly drops the
/// database handle, while the returned project guard stays alive, and delegates to
/// `join_project_internal` to notify collaborators and stream the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1853
/// Destination for the `JoinProjectResponse` produced by `join_project_internal`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The path-qualified call resolves to the inherent `Response::send`, not this trait
        // method, so it cannot recurse.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1862
1863fn join_project_internal(
1864 response: impl JoinProjectInternalResponse,
1865 session: Session,
1866 project: &mut Project,
1867 replica_id: &ReplicaId,
1868) -> Result<()> {
1869 let collaborators = project
1870 .collaborators
1871 .iter()
1872 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1873 .map(|collaborator| collaborator.to_proto())
1874 .collect::<Vec<_>>();
1875 let project_id = project.id;
1876 let guest_user_id = session.user_id();
1877
1878 let worktrees = project
1879 .worktrees
1880 .iter()
1881 .map(|(id, worktree)| proto::WorktreeMetadata {
1882 id: *id,
1883 root_name: worktree.root_name.clone(),
1884 visible: worktree.visible,
1885 abs_path: worktree.abs_path.clone(),
1886 })
1887 .collect::<Vec<_>>();
1888
1889 let add_project_collaborator = proto::AddProjectCollaborator {
1890 project_id: project_id.to_proto(),
1891 collaborator: Some(proto::Collaborator {
1892 peer_id: Some(session.connection_id.into()),
1893 replica_id: replica_id.0 as u32,
1894 user_id: guest_user_id.to_proto(),
1895 is_host: false,
1896 }),
1897 };
1898
1899 for collaborator in &collaborators {
1900 session
1901 .peer
1902 .send(
1903 collaborator.peer_id.unwrap().into(),
1904 add_project_collaborator.clone(),
1905 )
1906 .trace_err();
1907 }
1908
1909 // First, we send the metadata associated with each worktree.
1910 response.send(proto::JoinProjectResponse {
1911 project_id: project.id.0 as u64,
1912 worktrees: worktrees.clone(),
1913 replica_id: replica_id.0 as u32,
1914 collaborators: collaborators.clone(),
1915 language_servers: project.language_servers.clone(),
1916 role: project.role.into(),
1917 })?;
1918
1919 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1920 // Stream this worktree's entries.
1921 let message = proto::UpdateWorktree {
1922 project_id: project_id.to_proto(),
1923 worktree_id,
1924 abs_path: worktree.abs_path.clone(),
1925 root_name: worktree.root_name,
1926 updated_entries: worktree.entries,
1927 removed_entries: Default::default(),
1928 scan_id: worktree.scan_id,
1929 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1930 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1931 removed_repositories: Default::default(),
1932 };
1933 for update in proto::split_worktree_update(message) {
1934 session.peer.send(session.connection_id, update.clone())?;
1935 }
1936
1937 // Stream this worktree's diagnostics.
1938 for summary in worktree.diagnostic_summaries {
1939 session.peer.send(
1940 session.connection_id,
1941 proto::UpdateDiagnosticSummary {
1942 project_id: project_id.to_proto(),
1943 worktree_id: worktree.id,
1944 summary: Some(summary),
1945 },
1946 )?;
1947 }
1948
1949 for settings_file in worktree.settings_files {
1950 session.peer.send(
1951 session.connection_id,
1952 proto::UpdateWorktreeSettings {
1953 project_id: project_id.to_proto(),
1954 worktree_id: worktree.id,
1955 path: settings_file.path,
1956 content: Some(settings_file.content),
1957 kind: Some(settings_file.kind.to_proto() as i32),
1958 },
1959 )?;
1960 }
1961 }
1962
1963 for repository in mem::take(&mut project.repositories) {
1964 for update in split_repository_update(repository) {
1965 session.peer.send(session.connection_id, update)?;
1966 }
1967 }
1968
1969 for language_server in &project.language_servers {
1970 session.peer.send(
1971 session.connection_id,
1972 proto::UpdateLanguageServer {
1973 project_id: project_id.to_proto(),
1974 language_server_id: language_server.id,
1975 variant: Some(
1976 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1977 proto::LspDiskBasedDiagnosticsUpdated {},
1978 ),
1979 ),
1980 },
1981 )?;
1982 }
1983
1984 Ok(())
1985}
1986
1987/// Leave someone elses shared project.
1988async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1989 let sender_id = session.connection_id;
1990 let project_id = ProjectId::from_proto(request.project_id);
1991 let db = session.db().await;
1992
1993 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1994 tracing::info!(
1995 %project_id,
1996 "leave project"
1997 );
1998
1999 project_left(project, &session);
2000 if let Some(room) = room {
2001 room_updated(room, &session.peer);
2002 }
2003
2004 Ok(())
2005}
2006
2007/// Updates other participants with changes to the project
2008async fn update_project(
2009 request: proto::UpdateProject,
2010 response: Response<proto::UpdateProject>,
2011 session: Session,
2012) -> Result<()> {
2013 let project_id = ProjectId::from_proto(request.project_id);
2014 let (room, guest_connection_ids) = &*session
2015 .db()
2016 .await
2017 .update_project(project_id, session.connection_id, &request.worktrees)
2018 .await?;
2019 broadcast(
2020 Some(session.connection_id),
2021 guest_connection_ids.iter().copied(),
2022 |connection_id| {
2023 session
2024 .peer
2025 .forward_send(session.connection_id, connection_id, request.clone())
2026 },
2027 );
2028 if let Some(room) = room {
2029 room_updated(room, &session.peer);
2030 }
2031 response.send(proto::Ack {})?;
2032
2033 Ok(())
2034}
2035
2036/// Updates other participants with changes to the worktree
2037async fn update_worktree(
2038 request: proto::UpdateWorktree,
2039 response: Response<proto::UpdateWorktree>,
2040 session: Session,
2041) -> Result<()> {
2042 let guest_connection_ids = session
2043 .db()
2044 .await
2045 .update_worktree(&request, session.connection_id)
2046 .await?;
2047
2048 broadcast(
2049 Some(session.connection_id),
2050 guest_connection_ids.iter().copied(),
2051 |connection_id| {
2052 session
2053 .peer
2054 .forward_send(session.connection_id, connection_id, request.clone())
2055 },
2056 );
2057 response.send(proto::Ack {})?;
2058 Ok(())
2059}
2060
2061async fn update_repository(
2062 request: proto::UpdateRepository,
2063 response: Response<proto::UpdateRepository>,
2064 session: Session,
2065) -> Result<()> {
2066 let guest_connection_ids = session
2067 .db()
2068 .await
2069 .update_repository(&request, session.connection_id)
2070 .await?;
2071
2072 broadcast(
2073 Some(session.connection_id),
2074 guest_connection_ids.iter().copied(),
2075 |connection_id| {
2076 session
2077 .peer
2078 .forward_send(session.connection_id, connection_id, request.clone())
2079 },
2080 );
2081 response.send(proto::Ack {})?;
2082 Ok(())
2083}
2084
2085async fn remove_repository(
2086 request: proto::RemoveRepository,
2087 response: Response<proto::RemoveRepository>,
2088 session: Session,
2089) -> Result<()> {
2090 let guest_connection_ids = session
2091 .db()
2092 .await
2093 .remove_repository(&request, session.connection_id)
2094 .await?;
2095
2096 broadcast(
2097 Some(session.connection_id),
2098 guest_connection_ids.iter().copied(),
2099 |connection_id| {
2100 session
2101 .peer
2102 .forward_send(session.connection_id, connection_id, request.clone())
2103 },
2104 );
2105 response.send(proto::Ack {})?;
2106 Ok(())
2107}
2108
2109/// Updates other participants with changes to the diagnostics
2110async fn update_diagnostic_summary(
2111 message: proto::UpdateDiagnosticSummary,
2112 session: Session,
2113) -> Result<()> {
2114 let guest_connection_ids = session
2115 .db()
2116 .await
2117 .update_diagnostic_summary(&message, session.connection_id)
2118 .await?;
2119
2120 broadcast(
2121 Some(session.connection_id),
2122 guest_connection_ids.iter().copied(),
2123 |connection_id| {
2124 session
2125 .peer
2126 .forward_send(session.connection_id, connection_id, message.clone())
2127 },
2128 );
2129
2130 Ok(())
2131}
2132
2133/// Updates other participants with changes to the worktree settings
2134async fn update_worktree_settings(
2135 message: proto::UpdateWorktreeSettings,
2136 session: Session,
2137) -> Result<()> {
2138 let guest_connection_ids = session
2139 .db()
2140 .await
2141 .update_worktree_settings(&message, session.connection_id)
2142 .await?;
2143
2144 broadcast(
2145 Some(session.connection_id),
2146 guest_connection_ids.iter().copied(),
2147 |connection_id| {
2148 session
2149 .peer
2150 .forward_send(session.connection_id, connection_id, message.clone())
2151 },
2152 );
2153
2154 Ok(())
2155}
2156
2157/// Notify other participants that a language server has started.
2158async fn start_language_server(
2159 request: proto::StartLanguageServer,
2160 session: Session,
2161) -> Result<()> {
2162 let guest_connection_ids = session
2163 .db()
2164 .await
2165 .start_language_server(&request, session.connection_id)
2166 .await?;
2167
2168 broadcast(
2169 Some(session.connection_id),
2170 guest_connection_ids.iter().copied(),
2171 |connection_id| {
2172 session
2173 .peer
2174 .forward_send(session.connection_id, connection_id, request.clone())
2175 },
2176 );
2177 Ok(())
2178}
2179
2180/// Notify other participants that a language server has changed.
2181async fn update_language_server(
2182 request: proto::UpdateLanguageServer,
2183 session: Session,
2184) -> Result<()> {
2185 let project_id = ProjectId::from_proto(request.project_id);
2186 let project_connection_ids = session
2187 .db()
2188 .await
2189 .project_connection_ids(project_id, session.connection_id, true)
2190 .await?;
2191 broadcast(
2192 Some(session.connection_id),
2193 project_connection_ids.iter().copied(),
2194 |connection_id| {
2195 session
2196 .peer
2197 .forward_send(session.connection_id, connection_id, request.clone())
2198 },
2199 );
2200 Ok(())
2201}
2202
2203/// forward a project request to the host. These requests should be read only
2204/// as guests are allowed to send them.
2205async fn forward_read_only_project_request<T>(
2206 request: T,
2207 response: Response<T>,
2208 session: Session,
2209) -> Result<()>
2210where
2211 T: EntityMessage + RequestMessage,
2212{
2213 let project_id = ProjectId::from_proto(request.remote_entity_id());
2214 let host_connection_id = session
2215 .db()
2216 .await
2217 .host_for_read_only_project_request(project_id, session.connection_id)
2218 .await?;
2219 let payload = session
2220 .peer
2221 .forward_request(session.connection_id, host_connection_id, request)
2222 .await?;
2223 response.send(payload)?;
2224 Ok(())
2225}
2226
2227async fn forward_find_search_candidates_request(
2228 request: proto::FindSearchCandidates,
2229 response: Response<proto::FindSearchCandidates>,
2230 session: Session,
2231) -> Result<()> {
2232 let project_id = ProjectId::from_proto(request.remote_entity_id());
2233 let host_connection_id = session
2234 .db()
2235 .await
2236 .host_for_read_only_project_request(project_id, session.connection_id)
2237 .await?;
2238 let payload = session
2239 .peer
2240 .forward_request(session.connection_id, host_connection_id, request)
2241 .await?;
2242 response.send(payload)?;
2243 Ok(())
2244}
2245
2246/// forward a project request to the host. These requests are disallowed
2247/// for guests.
2248async fn forward_mutating_project_request<T>(
2249 request: T,
2250 response: Response<T>,
2251 session: Session,
2252) -> Result<()>
2253where
2254 T: EntityMessage + RequestMessage,
2255{
2256 let project_id = ProjectId::from_proto(request.remote_entity_id());
2257
2258 let host_connection_id = session
2259 .db()
2260 .await
2261 .host_for_mutating_project_request(project_id, session.connection_id)
2262 .await?;
2263 let payload = session
2264 .peer
2265 .forward_request(session.connection_id, host_connection_id, request)
2266 .await?;
2267 response.send(payload)?;
2268 Ok(())
2269}
2270
2271/// Notify other participants that a new buffer has been created
2272async fn create_buffer_for_peer(
2273 request: proto::CreateBufferForPeer,
2274 session: Session,
2275) -> Result<()> {
2276 session
2277 .db()
2278 .await
2279 .check_user_is_project_host(
2280 ProjectId::from_proto(request.project_id),
2281 session.connection_id,
2282 )
2283 .await?;
2284 let peer_id = request.peer_id.context("invalid peer id")?;
2285 session
2286 .peer
2287 .forward_send(session.connection_id, peer_id.into(), request)?;
2288 Ok(())
2289}
2290
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// Operations with no variant, or that only update selections, need
/// `Capability::ReadOnly`; any other operation needs `ReadWrite`. That
/// capability is passed to `connections_for_buffer_update` (presumably so it
/// can reject senders lacking it; verify in the db layer). Guests receive the
/// update as a one-way forward. The host, unless it is the sender, receives it
/// as a request whose reply is awaited before the sender is acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Escalate to read-write as soon as any operation does more than move selections.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // `guard` is scoped to this block, so it is dropped before the host
    // request below is awaited; only the host's connection id escapes.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host gets the update as a request, so the sender is only acked after
    // the host has replied (skipped when the sender is the host itself).
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2340
/// Relays an `UpdateContext` message to the project's host and guests.
///
/// The capability required of the sender is derived from the operation:
/// - a buffer operation with no variant, or one that only updates selections:
///   `ReadOnly`;
/// - any other buffer operation, a buffer-operation wrapper with no inner
///   operation, or any other context-operation variant: `ReadWrite`;
/// - a context operation with no variant at all: `ReadOnly`.
///
/// A message without an `operation` is rejected with "invalid operation".
async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the same one-way
    // broadcast as the guests; the sender id is passed so `broadcast` can
    // presumably skip it (NOTE(review): confirm in `broadcast`).
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2382
2383/// Notify other participants that a project has been updated.
2384async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2385 request: T,
2386 session: Session,
2387) -> Result<()> {
2388 let project_id = ProjectId::from_proto(request.remote_entity_id());
2389 let project_connection_ids = session
2390 .db()
2391 .await
2392 .project_connection_ids(project_id, session.connection_id, false)
2393 .await?;
2394
2395 broadcast(
2396 Some(session.connection_id),
2397 project_connection_ids.iter().copied(),
2398 |connection_id| {
2399 session
2400 .peer
2401 .forward_send(session.connection_id, connection_id, request.clone())
2402 },
2403 );
2404 Ok(())
2405}
2406
2407/// Start following another user in a call.
2408async fn follow(
2409 request: proto::Follow,
2410 response: Response<proto::Follow>,
2411 session: Session,
2412) -> Result<()> {
2413 let room_id = RoomId::from_proto(request.room_id);
2414 let project_id = request.project_id.map(ProjectId::from_proto);
2415 let leader_id = request.leader_id.context("invalid leader id")?.into();
2416 let follower_id = session.connection_id;
2417
2418 session
2419 .db()
2420 .await
2421 .check_room_participants(room_id, leader_id, session.connection_id)
2422 .await?;
2423
2424 let response_payload = session
2425 .peer
2426 .forward_request(session.connection_id, leader_id, request)
2427 .await?;
2428 response.send(response_payload)?;
2429
2430 if let Some(project_id) = project_id {
2431 let room = session
2432 .db()
2433 .await
2434 .follow(room_id, project_id, leader_id, follower_id)
2435 .await?;
2436 room_updated(&room, &session.peer);
2437 }
2438
2439 Ok(())
2440}
2441
2442/// Stop following another user in a call.
2443async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2444 let room_id = RoomId::from_proto(request.room_id);
2445 let project_id = request.project_id.map(ProjectId::from_proto);
2446 let leader_id = request.leader_id.context("invalid leader id")?.into();
2447 let follower_id = session.connection_id;
2448
2449 session
2450 .db()
2451 .await
2452 .check_room_participants(room_id, leader_id, session.connection_id)
2453 .await?;
2454
2455 session
2456 .peer
2457 .forward_send(session.connection_id, leader_id, request)?;
2458
2459 if let Some(project_id) = project_id {
2460 let room = session
2461 .db()
2462 .await
2463 .unfollow(room_id, project_id, leader_id, follower_id)
2464 .await?;
2465 room_updated(&room, &session.peer);
2466 }
2467
2468 Ok(())
2469}
2470
2471/// Notify everyone following you of your current location.
2472async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2473 let room_id = RoomId::from_proto(request.room_id);
2474 let database = session.db.lock().await;
2475
2476 let connection_ids = if let Some(project_id) = request.project_id {
2477 let project_id = ProjectId::from_proto(project_id);
2478 database
2479 .project_connection_ids(project_id, session.connection_id, true)
2480 .await?
2481 } else {
2482 database
2483 .room_connection_ids(room_id, session.connection_id)
2484 .await?
2485 };
2486
2487 // For now, don't send view update messages back to that view's current leader.
2488 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2489 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2490 _ => None,
2491 });
2492
2493 for connection_id in connection_ids.iter().cloned() {
2494 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2495 session
2496 .peer
2497 .forward_send(session.connection_id, connection_id, request.clone())?;
2498 }
2499 }
2500 Ok(())
2501}
2502
2503/// Get public data about users.
2504async fn get_users(
2505 request: proto::GetUsers,
2506 response: Response<proto::GetUsers>,
2507 session: Session,
2508) -> Result<()> {
2509 let user_ids = request
2510 .user_ids
2511 .into_iter()
2512 .map(UserId::from_proto)
2513 .collect();
2514 let users = session
2515 .db()
2516 .await
2517 .get_users_by_ids(user_ids)
2518 .await?
2519 .into_iter()
2520 .map(|user| proto::User {
2521 id: user.id.to_proto(),
2522 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2523 github_login: user.github_login,
2524 email: user.email_address,
2525 name: user.name,
2526 })
2527 .collect();
2528 response.send(proto::UsersResponse { users })?;
2529 Ok(())
2530}
2531
2532/// Search for users (to invite) buy Github login
2533async fn fuzzy_search_users(
2534 request: proto::FuzzySearchUsers,
2535 response: Response<proto::FuzzySearchUsers>,
2536 session: Session,
2537) -> Result<()> {
2538 let query = request.query;
2539 let users = match query.len() {
2540 0 => vec![],
2541 1 | 2 => session
2542 .db()
2543 .await
2544 .get_user_by_github_login(&query)
2545 .await?
2546 .into_iter()
2547 .collect(),
2548 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2549 };
2550 let users = users
2551 .into_iter()
2552 .filter(|user| user.id != session.user_id())
2553 .map(|user| proto::User {
2554 id: user.id.to_proto(),
2555 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2556 github_login: user.github_login,
2557 name: user.name,
2558 email: user.email_address,
2559 })
2560 .collect();
2561 response.send(proto::UsersResponse { users })?;
2562 Ok(())
2563}
2564
/// Send a contact request to another user.
///
/// Rejects requests addressed to oneself, records the request in the database,
/// then updates both sides on all of their connections: the responder is
/// added to the requester's outgoing requests, and the requester to the
/// responder's incoming requests. Any notifications produced by the database
/// are delivered before the request is acknowledged.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    // This pool guard stays alive through `send_notifications` below.
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2611
/// Accept or decline a contact request
///
/// A `Dismiss` response only dismisses the contact notification. Any other
/// response accepts the request when it is `Accept` and declines it otherwise.
/// Both users' connections then have the pending request removed. On
/// acceptance, each user is also sent the other as a new contact, including
/// that user's current busy status.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are fetched for both users so each new contact entry
        // reflects the other user's current state.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2669
/// Remove a contact.
///
/// Handles both an accepted contact and a still-pending outgoing request. The
/// database reports which it was (`contact_accepted`), and each side's
/// connections receive the matching removal. If the database deleted a
/// notification, the other user's connections are also told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Only the responder's connections are asked to drop the notification.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2720
2721fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2722 version.0.minor() < 139
2723}
2724
2725async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2726 if is_staff {
2727 return Ok(proto::Plan::ZedPro);
2728 }
2729
2730 let subscription = db.get_active_billing_subscription(user_id).await?;
2731 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2732
2733 let plan = if let Some(subscription_kind) = subscription_kind {
2734 match subscription_kind {
2735 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2736 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2737 SubscriptionKind::ZedFree => proto::Plan::Free,
2738 }
2739 } else {
2740 proto::Plan::Free
2741 };
2742
2743 Ok(plan)
2744}
2745
/// Builds the `UpdateUserPlan` message describing `user`'s plan.
///
/// The message includes the trial start time, the usage-based-billing
/// setting, the current subscription period, the account-age gate, and the
/// overdue-invoice flag. When `llm_db` is provided and a subscription period
/// exists, it also reports usage against the plan's limits.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Subscription period and usage are only computed when an LLM database is available.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Every plan except Zed Pro is subject to the minimum-account-age gate.
    let account_too_young =
        !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report usage-based billing as enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    // Trial users with the extended-trial feature flag get a
                    // raised limit of 1,000 model requests.
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below silently truncate/wrap if the
            // source values fall outside u32's range; confirm their types and ranges.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2854
2855async fn update_user_plan(session: &Session) -> Result<()> {
2856 let db = session.db().await;
2857
2858 let update_user_plan = make_update_user_plan_message(
2859 session.principal.user(),
2860 session.is_staff(),
2861 &db.0,
2862 session.app_state.llm_db.clone(),
2863 )
2864 .await?;
2865
2866 session
2867 .peer
2868 .send(session.connection_id, update_user_plan)
2869 .trace_err();
2870
2871 Ok(())
2872}
2873
2874async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2875 subscribe_user_to_channels(session.user_id(), &session).await?;
2876 Ok(())
2877}
2878
2879async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2880 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2881 let mut pool = session.connection_pool().await;
2882 for membership in &channels_for_user.channel_memberships {
2883 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2884 }
2885 session.peer.send(
2886 session.connection_id,
2887 build_update_user_channels(&channels_for_user),
2888 )?;
2889 session.peer.send(
2890 session.connection_id,
2891 build_channels_update(channels_for_user),
2892 )?;
2893 Ok(())
2894}
2895
2896/// Creates a new channel.
2897async fn create_channel(
2898 request: proto::CreateChannel,
2899 response: Response<proto::CreateChannel>,
2900 session: Session,
2901) -> Result<()> {
2902 let db = session.db().await;
2903
2904 let parent_id = request.parent_id.map(ChannelId::from_proto);
2905 let (channel, membership) = db
2906 .create_channel(&request.name, parent_id, session.user_id())
2907 .await?;
2908
2909 let root_id = channel.root_id();
2910 let channel = Channel::from_model(channel);
2911
2912 response.send(proto::CreateChannelResponse {
2913 channel: Some(channel.to_proto()),
2914 parent_id: request.parent_id,
2915 })?;
2916
2917 let mut connection_pool = session.connection_pool().await;
2918 if let Some(membership) = membership {
2919 connection_pool.subscribe_to_channel(
2920 membership.user_id,
2921 membership.channel_id,
2922 membership.role,
2923 );
2924 let update = proto::UpdateUserChannels {
2925 channel_memberships: vec![proto::ChannelMembership {
2926 channel_id: membership.channel_id.to_proto(),
2927 role: membership.role.into(),
2928 }],
2929 ..Default::default()
2930 };
2931 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2932 session.peer.send(connection_id, update.clone())?;
2933 }
2934 }
2935
2936 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2937 if !role.can_see_channel(channel.visibility) {
2938 continue;
2939 }
2940
2941 let update = proto::UpdateChannels {
2942 channels: vec![channel.to_proto()],
2943 ..Default::default()
2944 };
2945 session.peer.send(connection_id, update.clone())?;
2946 }
2947
2948 Ok(())
2949}
2950
2951/// Delete a channel
2952async fn delete_channel(
2953 request: proto::DeleteChannel,
2954 response: Response<proto::DeleteChannel>,
2955 session: Session,
2956) -> Result<()> {
2957 let db = session.db().await;
2958
2959 let channel_id = request.channel_id;
2960 let (root_channel, removed_channels) = db
2961 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2962 .await?;
2963 response.send(proto::Ack {})?;
2964
2965 // Notify members of removed channels
2966 let mut update = proto::UpdateChannels::default();
2967 update
2968 .delete_channels
2969 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2970
2971 let connection_pool = session.connection_pool().await;
2972 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2973 session.peer.send(connection_id, update.clone())?;
2974 }
2975
2976 Ok(())
2977}
2978
2979/// Invite someone to join a channel.
2980async fn invite_channel_member(
2981 request: proto::InviteChannelMember,
2982 response: Response<proto::InviteChannelMember>,
2983 session: Session,
2984) -> Result<()> {
2985 let db = session.db().await;
2986 let channel_id = ChannelId::from_proto(request.channel_id);
2987 let invitee_id = UserId::from_proto(request.user_id);
2988 let InviteMemberResult {
2989 channel,
2990 notifications,
2991 } = db
2992 .invite_channel_member(
2993 channel_id,
2994 invitee_id,
2995 session.user_id(),
2996 request.role().into(),
2997 )
2998 .await?;
2999
3000 let update = proto::UpdateChannels {
3001 channel_invitations: vec![channel.to_proto()],
3002 ..Default::default()
3003 };
3004
3005 let connection_pool = session.connection_pool().await;
3006 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3007 session.peer.send(connection_id, update.clone())?;
3008 }
3009
3010 send_notifications(&connection_pool, &session.peer, notifications);
3011
3012 response.send(proto::Ack {})?;
3013 Ok(())
3014}
3015
3016/// remove someone from a channel
3017async fn remove_channel_member(
3018 request: proto::RemoveChannelMember,
3019 response: Response<proto::RemoveChannelMember>,
3020 session: Session,
3021) -> Result<()> {
3022 let db = session.db().await;
3023 let channel_id = ChannelId::from_proto(request.channel_id);
3024 let member_id = UserId::from_proto(request.user_id);
3025
3026 let RemoveChannelMemberResult {
3027 membership_update,
3028 notification_id,
3029 } = db
3030 .remove_channel_member(channel_id, member_id, session.user_id())
3031 .await?;
3032
3033 let mut connection_pool = session.connection_pool().await;
3034 notify_membership_updated(
3035 &mut connection_pool,
3036 membership_update,
3037 member_id,
3038 &session.peer,
3039 );
3040 for connection_id in connection_pool.user_connection_ids(member_id) {
3041 if let Some(notification_id) = notification_id {
3042 session
3043 .peer
3044 .send(
3045 connection_id,
3046 proto::DeleteNotification {
3047 notification_id: notification_id.to_proto(),
3048 },
3049 )
3050 .trace_err();
3051 }
3052 }
3053
3054 response.send(proto::Ack {})?;
3055 Ok(())
3056}
3057
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated against the new visibility. Users whose role can see the
/// channel are subscribed to it and sent its data. All others are unsubscribed
/// and told to delete it. The request is acknowledged once every update has
/// been sent.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the subscribers first: the loop body mutates `connection_pool`
    // (subscribe/unsubscribe), so it cannot be iterated while borrowed.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3104
3105/// Alter the role for a user in the channel.
3106async fn set_channel_member_role(
3107 request: proto::SetChannelMemberRole,
3108 response: Response<proto::SetChannelMemberRole>,
3109 session: Session,
3110) -> Result<()> {
3111 let db = session.db().await;
3112 let channel_id = ChannelId::from_proto(request.channel_id);
3113 let member_id = UserId::from_proto(request.user_id);
3114 let result = db
3115 .set_channel_member_role(
3116 channel_id,
3117 session.user_id(),
3118 member_id,
3119 request.role().into(),
3120 )
3121 .await?;
3122
3123 match result {
3124 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3125 let mut connection_pool = session.connection_pool().await;
3126 notify_membership_updated(
3127 &mut connection_pool,
3128 membership_update,
3129 member_id,
3130 &session.peer,
3131 )
3132 }
3133 db::SetMemberRoleResult::InviteUpdated(channel) => {
3134 let update = proto::UpdateChannels {
3135 channel_invitations: vec![channel.to_proto()],
3136 ..Default::default()
3137 };
3138
3139 for connection_id in session
3140 .connection_pool()
3141 .await
3142 .user_connection_ids(member_id)
3143 {
3144 session.peer.send(connection_id, update.clone())?;
3145 }
3146 }
3147 }
3148
3149 response.send(proto::Ack {})?;
3150 Ok(())
3151}
3152
3153/// Change the name of a channel
3154async fn rename_channel(
3155 request: proto::RenameChannel,
3156 response: Response<proto::RenameChannel>,
3157 session: Session,
3158) -> Result<()> {
3159 let db = session.db().await;
3160 let channel_id = ChannelId::from_proto(request.channel_id);
3161 let channel_model = db
3162 .rename_channel(channel_id, session.user_id(), &request.name)
3163 .await?;
3164 let root_id = channel_model.root_id();
3165 let channel = Channel::from_model(channel_model);
3166
3167 response.send(proto::RenameChannelResponse {
3168 channel: Some(channel.to_proto()),
3169 })?;
3170
3171 let connection_pool = session.connection_pool().await;
3172 let update = proto::UpdateChannels {
3173 channels: vec![channel.to_proto()],
3174 ..Default::default()
3175 };
3176 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3177 if role.can_see_channel(channel.visibility) {
3178 session.peer.send(connection_id, update.clone())?;
3179 }
3180 }
3181
3182 Ok(())
3183}
3184
3185/// Move a channel to a new parent.
3186async fn move_channel(
3187 request: proto::MoveChannel,
3188 response: Response<proto::MoveChannel>,
3189 session: Session,
3190) -> Result<()> {
3191 let channel_id = ChannelId::from_proto(request.channel_id);
3192 let to = ChannelId::from_proto(request.to);
3193
3194 let (root_id, channels) = session
3195 .db()
3196 .await
3197 .move_channel(channel_id, to, session.user_id())
3198 .await?;
3199
3200 let connection_pool = session.connection_pool().await;
3201 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3202 let channels = channels
3203 .iter()
3204 .filter_map(|channel| {
3205 if role.can_see_channel(channel.visibility) {
3206 Some(channel.to_proto())
3207 } else {
3208 None
3209 }
3210 })
3211 .collect::<Vec<_>>();
3212 if channels.is_empty() {
3213 continue;
3214 }
3215
3216 let update = proto::UpdateChannels {
3217 channels,
3218 ..Default::default()
3219 };
3220
3221 session.peer.send(connection_id, update.clone())?;
3222 }
3223
3224 response.send(Ack {})?;
3225 Ok(())
3226}
3227
3228async fn reorder_channel(
3229 request: proto::ReorderChannel,
3230 response: Response<proto::ReorderChannel>,
3231 session: Session,
3232) -> Result<()> {
3233 let channel_id = ChannelId::from_proto(request.channel_id);
3234 let direction = request.direction();
3235
3236 let updated_channels = session
3237 .db()
3238 .await
3239 .reorder_channel(channel_id, direction, session.user_id())
3240 .await?;
3241
3242 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3243 let connection_pool = session.connection_pool().await;
3244 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3245 let channels = updated_channels
3246 .iter()
3247 .filter_map(|channel| {
3248 if role.can_see_channel(channel.visibility) {
3249 Some(channel.to_proto())
3250 } else {
3251 None
3252 }
3253 })
3254 .collect::<Vec<_>>();
3255
3256 if channels.is_empty() {
3257 continue;
3258 }
3259
3260 let update = proto::UpdateChannels {
3261 channels,
3262 ..Default::default()
3263 };
3264
3265 session.peer.send(connection_id, update.clone())?;
3266 }
3267 }
3268
3269 response.send(Ack {})?;
3270 Ok(())
3271}
3272
3273/// Get the list of channel members
3274async fn get_channel_members(
3275 request: proto::GetChannelMembers,
3276 response: Response<proto::GetChannelMembers>,
3277 session: Session,
3278) -> Result<()> {
3279 let db = session.db().await;
3280 let channel_id = ChannelId::from_proto(request.channel_id);
3281 let limit = if request.limit == 0 {
3282 u16::MAX as u64
3283 } else {
3284 request.limit
3285 };
3286 let (members, users) = db
3287 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3288 .await?;
3289 response.send(proto::GetChannelMembersResponse { members, users })?;
3290 Ok(())
3291}
3292
3293/// Accept or decline a channel invitation.
3294async fn respond_to_channel_invite(
3295 request: proto::RespondToChannelInvite,
3296 response: Response<proto::RespondToChannelInvite>,
3297 session: Session,
3298) -> Result<()> {
3299 let db = session.db().await;
3300 let channel_id = ChannelId::from_proto(request.channel_id);
3301 let RespondToChannelInvite {
3302 membership_update,
3303 notifications,
3304 } = db
3305 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3306 .await?;
3307
3308 let mut connection_pool = session.connection_pool().await;
3309 if let Some(membership_update) = membership_update {
3310 notify_membership_updated(
3311 &mut connection_pool,
3312 membership_update,
3313 session.user_id(),
3314 &session.peer,
3315 );
3316 } else {
3317 let update = proto::UpdateChannels {
3318 remove_channel_invitations: vec![channel_id.to_proto()],
3319 ..Default::default()
3320 };
3321
3322 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3323 session.peer.send(connection_id, update.clone())?;
3324 }
3325 };
3326
3327 send_notifications(&connection_pool, &session.peer, notifications);
3328
3329 response.send(proto::Ack {})?;
3330
3331 Ok(())
3332}
3333
3334/// Join the channels' room
3335async fn join_channel(
3336 request: proto::JoinChannel,
3337 response: Response<proto::JoinChannel>,
3338 session: Session,
3339) -> Result<()> {
3340 let channel_id = ChannelId::from_proto(request.channel_id);
3341 join_channel_internal(channel_id, Box::new(response), session).await
3342}
3343
/// Abstraction over RPC responses that reply with a `JoinRoomResponse`.
///
/// It lets `join_channel_internal` answer either a `JoinChannel` or a
/// `JoinRoom` request with the same code.
trait JoinChannelInternalResponse {
    /// Consume the response handle and send `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path is meant to reach `Response`'s own `send`
        // (inherent items take precedence over trait items), not recurse into
        // this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation pattern, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3357
/// Shared implementation for joining a channel's room.
///
/// It is generic over the response type, so it can answer any request whose
/// response implements `JoinChannelInternalResponse`. Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The db handle is dropped while doing so and then
///    re-acquired.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, build connection info. `Guest`
///    roles get `can_publish = false`. Then reply to the requester.
/// 4. Apply and broadcast the membership update (if any) and the room update.
/// 5. Outside the inner scope, broadcast the channel update and refresh the
///    user's contacts.
///
/// # Errors
/// Fails on database errors, on a failed reply, or if the joined room has no
/// channel. That last check runs *after* the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the db handle and the connection-pool guard taken inside are
    // dropped before `channel_updated` locks the pool again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and cannot publish; every other role gets a
        // room token with publishing enabled. A token error is turned into
        // `None` by `trace_err()?`, so the join proceeds without LiveKit info
        // rather than failing.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): this error path is reached only after the requester has
    // already been sent a successful `JoinRoomResponse`.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3451
3452/// Start editing the channel notes
3453async fn join_channel_buffer(
3454 request: proto::JoinChannelBuffer,
3455 response: Response<proto::JoinChannelBuffer>,
3456 session: Session,
3457) -> Result<()> {
3458 let db = session.db().await;
3459 let channel_id = ChannelId::from_proto(request.channel_id);
3460
3461 let open_response = db
3462 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3463 .await?;
3464
3465 let collaborators = open_response.collaborators.clone();
3466 response.send(open_response)?;
3467
3468 let update = UpdateChannelBufferCollaborators {
3469 channel_id: channel_id.to_proto(),
3470 collaborators: collaborators.clone(),
3471 };
3472 channel_buffer_updated(
3473 session.connection_id,
3474 collaborators
3475 .iter()
3476 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3477 &update,
3478 &session.peer,
3479 );
3480
3481 Ok(())
3482}
3483
3484/// Edit the channel notes
3485async fn update_channel_buffer(
3486 request: proto::UpdateChannelBuffer,
3487 session: Session,
3488) -> Result<()> {
3489 let db = session.db().await;
3490 let channel_id = ChannelId::from_proto(request.channel_id);
3491
3492 let (collaborators, epoch, version) = db
3493 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3494 .await?;
3495
3496 channel_buffer_updated(
3497 session.connection_id,
3498 collaborators.clone(),
3499 &proto::UpdateChannelBuffer {
3500 channel_id: channel_id.to_proto(),
3501 operations: request.operations,
3502 },
3503 &session.peer,
3504 );
3505
3506 let pool = &*session.connection_pool().await;
3507
3508 let non_collaborators =
3509 pool.channel_connection_ids(channel_id)
3510 .filter_map(|(connection_id, _)| {
3511 if collaborators.contains(&connection_id) {
3512 None
3513 } else {
3514 Some(connection_id)
3515 }
3516 });
3517
3518 broadcast(None, non_collaborators, |peer_id| {
3519 session.peer.send(
3520 peer_id,
3521 proto::UpdateChannels {
3522 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3523 channel_id: channel_id.to_proto(),
3524 epoch: epoch as u64,
3525 version: version.clone(),
3526 }],
3527 ..Default::default()
3528 },
3529 )
3530 });
3531
3532 Ok(())
3533}
3534
3535/// Rejoin the channel notes after a connection blip
3536async fn rejoin_channel_buffers(
3537 request: proto::RejoinChannelBuffers,
3538 response: Response<proto::RejoinChannelBuffers>,
3539 session: Session,
3540) -> Result<()> {
3541 let db = session.db().await;
3542 let buffers = db
3543 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3544 .await?;
3545
3546 for rejoined_buffer in &buffers {
3547 let collaborators_to_notify = rejoined_buffer
3548 .buffer
3549 .collaborators
3550 .iter()
3551 .filter_map(|c| Some(c.peer_id?.into()));
3552 channel_buffer_updated(
3553 session.connection_id,
3554 collaborators_to_notify,
3555 &proto::UpdateChannelBufferCollaborators {
3556 channel_id: rejoined_buffer.buffer.channel_id,
3557 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3558 },
3559 &session.peer,
3560 );
3561 }
3562
3563 response.send(proto::RejoinChannelBuffersResponse {
3564 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3565 })?;
3566
3567 Ok(())
3568}
3569
3570/// Stop editing the channel notes
3571async fn leave_channel_buffer(
3572 request: proto::LeaveChannelBuffer,
3573 response: Response<proto::LeaveChannelBuffer>,
3574 session: Session,
3575) -> Result<()> {
3576 let db = session.db().await;
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578
3579 let left_buffer = db
3580 .leave_channel_buffer(channel_id, session.connection_id)
3581 .await?;
3582
3583 response.send(Ack {})?;
3584
3585 channel_buffer_updated(
3586 session.connection_id,
3587 left_buffer.connections,
3588 &proto::UpdateChannelBufferCollaborators {
3589 channel_id: channel_id.to_proto(),
3590 collaborators: left_buffer.collaborators,
3591 },
3592 &session.peer,
3593 );
3594
3595 Ok(())
3596}
3597
3598fn channel_buffer_updated<T: EnvelopedMessage>(
3599 sender_id: ConnectionId,
3600 collaborators: impl IntoIterator<Item = ConnectionId>,
3601 message: &T,
3602 peer: &Peer,
3603) {
3604 broadcast(Some(sender_id), collaborators, |peer_id| {
3605 peer.send(peer_id, message.clone())
3606 });
3607}
3608
3609fn send_notifications(
3610 connection_pool: &ConnectionPool,
3611 peer: &Peer,
3612 notifications: db::NotificationBatch,
3613) {
3614 for (user_id, notification) in notifications {
3615 for connection_id in connection_pool.user_connection_ids(user_id) {
3616 if let Err(error) = peer.send(
3617 connection_id,
3618 proto::AddNotification {
3619 notification: Some(notification.clone()),
3620 },
3621 ) {
3622 tracing::error!(
3623 "failed to send notification to {:?} {}",
3624 connection_id,
3625 error
3626 );
3627 }
3628 }
3629 }
3630}
3631
/// Send a message to the channel.
///
/// The body is trimmed and rejected if it is longer than `MAX_MESSAGE_LEN`
/// or empty. A nonce is required. After the message is stored:
/// - the other chat participants receive the full message (`ChannelMessageSent`);
/// - the requester receives the stored message;
/// - connections subscribed to the channel but not participating in the
///   chat receive only the latest message id;
/// - any notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // NOTE(review): `body.len()` is a byte length, measured after trimming.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // (Trimming leading whitespace shifts the body. If mention ranges are
    // offsets into the untrimmed body, as seems likely, they no longer line
    // up with the stored text. Verify.)

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to the chat participants; the sender's own
    // connection is passed as the excluded origin.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel that are not chat participants
    // only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3724
3725/// Delete a channel message
3726async fn remove_channel_message(
3727 request: proto::RemoveChannelMessage,
3728 response: Response<proto::RemoveChannelMessage>,
3729 session: Session,
3730) -> Result<()> {
3731 let channel_id = ChannelId::from_proto(request.channel_id);
3732 let message_id = MessageId::from_proto(request.message_id);
3733 let (connection_ids, existing_notification_ids) = session
3734 .db()
3735 .await
3736 .remove_channel_message(channel_id, message_id, session.user_id())
3737 .await?;
3738
3739 broadcast(
3740 Some(session.connection_id),
3741 connection_ids,
3742 move |connection| {
3743 session.peer.send(connection, request.clone())?;
3744
3745 for notification_id in &existing_notification_ids {
3746 session.peer.send(
3747 connection,
3748 proto::DeleteNotification {
3749 notification_id: (*notification_id).to_proto(),
3750 },
3751 )?;
3752 }
3753
3754 Ok(())
3755 },
3756 );
3757 response.send(proto::Ack {})?;
3758 Ok(())
3759}
3760
3761async fn update_channel_message(
3762 request: proto::UpdateChannelMessage,
3763 response: Response<proto::UpdateChannelMessage>,
3764 session: Session,
3765) -> Result<()> {
3766 let channel_id = ChannelId::from_proto(request.channel_id);
3767 let message_id = MessageId::from_proto(request.message_id);
3768 let updated_at = OffsetDateTime::now_utc();
3769 let UpdatedChannelMessage {
3770 message_id,
3771 participant_connection_ids,
3772 notifications,
3773 reply_to_message_id,
3774 timestamp,
3775 deleted_mention_notification_ids,
3776 updated_mention_notifications,
3777 } = session
3778 .db()
3779 .await
3780 .update_channel_message(
3781 channel_id,
3782 message_id,
3783 session.user_id(),
3784 request.body.as_str(),
3785 &request.mentions,
3786 updated_at,
3787 )
3788 .await?;
3789
3790 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3791
3792 let message = proto::ChannelMessage {
3793 sender_id: session.user_id().to_proto(),
3794 id: message_id.to_proto(),
3795 body: request.body.clone(),
3796 mentions: request.mentions.clone(),
3797 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3798 nonce: Some(nonce),
3799 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3800 edited_at: Some(updated_at.unix_timestamp() as u64),
3801 };
3802
3803 response.send(proto::Ack {})?;
3804
3805 let pool = &*session.connection_pool().await;
3806 broadcast(
3807 Some(session.connection_id),
3808 participant_connection_ids,
3809 |connection| {
3810 session.peer.send(
3811 connection,
3812 proto::ChannelMessageUpdate {
3813 channel_id: channel_id.to_proto(),
3814 message: Some(message.clone()),
3815 },
3816 )?;
3817
3818 for notification_id in &deleted_mention_notification_ids {
3819 session.peer.send(
3820 connection,
3821 proto::DeleteNotification {
3822 notification_id: (*notification_id).to_proto(),
3823 },
3824 )?;
3825 }
3826
3827 for notification in &updated_mention_notifications {
3828 session.peer.send(
3829 connection,
3830 proto::UpdateNotification {
3831 notification: Some(notification.clone()),
3832 },
3833 )?;
3834 }
3835
3836 Ok(())
3837 },
3838 );
3839
3840 send_notifications(pool, &session.peer, notifications);
3841
3842 Ok(())
3843}
3844
3845/// Mark a channel message as read
3846async fn acknowledge_channel_message(
3847 request: proto::AckChannelMessage,
3848 session: Session,
3849) -> Result<()> {
3850 let channel_id = ChannelId::from_proto(request.channel_id);
3851 let message_id = MessageId::from_proto(request.message_id);
3852 let notifications = session
3853 .db()
3854 .await
3855 .observe_channel_message(channel_id, session.user_id(), message_id)
3856 .await?;
3857 send_notifications(
3858 &*session.connection_pool().await,
3859 &session.peer,
3860 notifications,
3861 );
3862 Ok(())
3863}
3864
3865/// Mark a buffer version as synced
3866async fn acknowledge_buffer_version(
3867 request: proto::AckBufferOperation,
3868 session: Session,
3869) -> Result<()> {
3870 let buffer_id = BufferId::from_proto(request.buffer_id);
3871 session
3872 .db()
3873 .await
3874 .observe_buffer_version(
3875 buffer_id,
3876 session.user_id(),
3877 request.epoch as i32,
3878 &request.version,
3879 )
3880 .await?;
3881 Ok(())
3882}
3883
3884/// Get a Supermaven API key for the user
3885async fn get_supermaven_api_key(
3886 _request: proto::GetSupermavenApiKey,
3887 response: Response<proto::GetSupermavenApiKey>,
3888 session: Session,
3889) -> Result<()> {
3890 let user_id: String = session.user_id().to_string();
3891 if !session.is_staff() {
3892 return Err(anyhow!("supermaven not enabled for this account"))?;
3893 }
3894
3895 let email = session.email().context("user must have an email")?;
3896
3897 let supermaven_admin_api = session
3898 .supermaven_client
3899 .as_ref()
3900 .context("supermaven not configured")?;
3901
3902 let result = supermaven_admin_api
3903 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3904 .await?;
3905
3906 response.send(proto::GetSupermavenApiKeyResponse {
3907 api_key: result.api_key,
3908 })?;
3909
3910 Ok(())
3911}
3912
3913/// Start receiving chat updates for a channel
3914async fn join_channel_chat(
3915 request: proto::JoinChannelChat,
3916 response: Response<proto::JoinChannelChat>,
3917 session: Session,
3918) -> Result<()> {
3919 let channel_id = ChannelId::from_proto(request.channel_id);
3920
3921 let db = session.db().await;
3922 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3923 .await?;
3924 let messages = db
3925 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3926 .await?;
3927 response.send(proto::JoinChannelChatResponse {
3928 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3929 messages,
3930 })?;
3931 Ok(())
3932}
3933
3934/// Stop receiving chat updates for a channel
3935async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3936 let channel_id = ChannelId::from_proto(request.channel_id);
3937 session
3938 .db()
3939 .await
3940 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3941 .await?;
3942 Ok(())
3943}
3944
3945/// Retrieve the chat history for a channel
3946async fn get_channel_messages(
3947 request: proto::GetChannelMessages,
3948 response: Response<proto::GetChannelMessages>,
3949 session: Session,
3950) -> Result<()> {
3951 let channel_id = ChannelId::from_proto(request.channel_id);
3952 let messages = session
3953 .db()
3954 .await
3955 .get_channel_messages(
3956 channel_id,
3957 session.user_id(),
3958 MESSAGE_COUNT_PER_PAGE,
3959 Some(MessageId::from_proto(request.before_message_id)),
3960 )
3961 .await?;
3962 response.send(proto::GetChannelMessagesResponse {
3963 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3964 messages,
3965 })?;
3966 Ok(())
3967}
3968
3969/// Retrieve specific chat messages
3970async fn get_channel_messages_by_id(
3971 request: proto::GetChannelMessagesById,
3972 response: Response<proto::GetChannelMessagesById>,
3973 session: Session,
3974) -> Result<()> {
3975 let message_ids = request
3976 .message_ids
3977 .iter()
3978 .map(|id| MessageId::from_proto(*id))
3979 .collect::<Vec<_>>();
3980 let messages = session
3981 .db()
3982 .await
3983 .get_channel_messages_by_id(session.user_id(), &message_ids)
3984 .await?;
3985 response.send(proto::GetChannelMessagesResponse {
3986 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3987 messages,
3988 })?;
3989 Ok(())
3990}
3991
3992/// Retrieve the current users notifications
3993async fn get_notifications(
3994 request: proto::GetNotifications,
3995 response: Response<proto::GetNotifications>,
3996 session: Session,
3997) -> Result<()> {
3998 let notifications = session
3999 .db()
4000 .await
4001 .get_notifications(
4002 session.user_id(),
4003 NOTIFICATION_COUNT_PER_PAGE,
4004 request.before_id.map(db::NotificationId::from_proto),
4005 )
4006 .await?;
4007 response.send(proto::GetNotificationsResponse {
4008 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4009 notifications,
4010 })?;
4011 Ok(())
4012}
4013
4014/// Mark notifications as read
4015async fn mark_notification_as_read(
4016 request: proto::MarkNotificationRead,
4017 response: Response<proto::MarkNotificationRead>,
4018 session: Session,
4019) -> Result<()> {
4020 let database = &session.db().await;
4021 let notifications = database
4022 .mark_notification_as_read_by_id(
4023 session.user_id(),
4024 NotificationId::from_proto(request.notification_id),
4025 )
4026 .await?;
4027 send_notifications(
4028 &*session.connection_pool().await,
4029 &session.peer,
4030 notifications,
4031 );
4032 response.send(proto::Ack {})?;
4033 Ok(())
4034}
4035
4036/// Get the current users information
4037async fn get_private_user_info(
4038 _request: proto::GetPrivateUserInfo,
4039 response: Response<proto::GetPrivateUserInfo>,
4040 session: Session,
4041) -> Result<()> {
4042 let db = session.db().await;
4043
4044 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4045 let user = db
4046 .get_user_by_id(session.user_id())
4047 .await?
4048 .context("user not found")?;
4049 let flags = db.get_user_flags(session.user_id()).await?;
4050
4051 response.send(proto::GetPrivateUserInfoResponse {
4052 metrics_id,
4053 staff: user.admin,
4054 flags,
4055 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4056 })?;
4057 Ok(())
4058}
4059
4060/// Accept the terms of service (tos) on behalf of the current user
4061async fn accept_terms_of_service(
4062 _request: proto::AcceptTermsOfService,
4063 response: Response<proto::AcceptTermsOfService>,
4064 session: Session,
4065) -> Result<()> {
4066 let db = session.db().await;
4067
4068 let accepted_tos_at = Utc::now();
4069 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4070 .await?;
4071
4072 response.send(proto::AcceptTermsOfServiceResponse {
4073 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4074 })?;
4075 Ok(())
4076}
4077
/// The minimum account age an account must have in order to use the LLM service.
///
/// Set to 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4080
/// Issue an LLM API token for the current user.
///
/// The user must have accepted the terms of service. Before issuing the
/// token, this handler ensures the user has:
/// - a billing customer, looked up in the database or, if missing, found or
///   created through Stripe by email;
/// - an active billing subscription, created as a `ZedFree` Stripe
///   subscription if missing.
///
/// The token is produced by `LlmTokenClaims::create`. Its inputs are the
/// user, staff status, billing data and preferences, feature flags, this
/// session's system id, and the app config.
///
/// # Errors
/// Fails if the user is missing, has not accepted the terms of service,
/// Stripe is not configured, or any database or Stripe call fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    // NOTE(review): this `db` handle stays in scope across the Stripe network
    // calls below. If it guards shared state, consider narrowing its scope.
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or resolve one through Stripe by email.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription, or create a Zed Free one in Stripe and
    // record it in the database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4164
4165fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4166 let message = match message {
4167 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4168 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4169 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4170 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4171 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4172 code: frame.code.into(),
4173 reason: frame.reason.as_str().to_owned().into(),
4174 })),
4175 // We should never receive a frame while reading the message, according
4176 // to the `tungstenite` maintainers:
4177 //
4178 // > It cannot occur when you read messages from the WebSocket, but it
4179 // > can be used when you want to send the raw frames (e.g. you want to
4180 // > send the frames to the WebSocket without composing the full message first).
4181 // >
4182 // > — https://github.com/snapview/tungstenite-rs/issues/268
4183 TungsteniteMessage::Frame(_) => {
4184 bail!("received an unexpected frame while reading the message")
4185 }
4186 };
4187
4188 Ok(message)
4189}
4190
4191fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4192 match message {
4193 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4194 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4195 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4196 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4197 AxumMessage::Close(frame) => {
4198 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4199 code: frame.code.into(),
4200 reason: frame.reason.as_ref().into(),
4201 }))
4202 }
4203 }
4204}
4205
4206fn notify_membership_updated(
4207 connection_pool: &mut ConnectionPool,
4208 result: MembershipUpdated,
4209 user_id: UserId,
4210 peer: &Peer,
4211) {
4212 for membership in &result.new_channels.channel_memberships {
4213 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4214 }
4215 for channel_id in &result.removed_channels {
4216 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4217 }
4218
4219 let user_channels_update = proto::UpdateUserChannels {
4220 channel_memberships: result
4221 .new_channels
4222 .channel_memberships
4223 .iter()
4224 .map(|cm| proto::ChannelMembership {
4225 channel_id: cm.channel_id.to_proto(),
4226 role: cm.role.into(),
4227 })
4228 .collect(),
4229 ..Default::default()
4230 };
4231
4232 let mut update = build_channels_update(result.new_channels);
4233 update.delete_channels = result
4234 .removed_channels
4235 .into_iter()
4236 .map(|id| id.to_proto())
4237 .collect();
4238 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4239
4240 for connection_id in connection_pool.user_connection_ids(user_id) {
4241 peer.send(connection_id, user_channels_update.clone())
4242 .trace_err();
4243 peer.send(connection_id, update.clone()).trace_err();
4244 }
4245}
4246
4247fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4248 proto::UpdateUserChannels {
4249 channel_memberships: channels
4250 .channel_memberships
4251 .iter()
4252 .map(|m| proto::ChannelMembership {
4253 channel_id: m.channel_id.to_proto(),
4254 role: m.role.into(),
4255 })
4256 .collect(),
4257 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4258 observed_channel_message_id: channels.observed_channel_messages.clone(),
4259 }
4260}
4261
4262fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4263 let mut update = proto::UpdateChannels::default();
4264
4265 for channel in channels.channels {
4266 update.channels.push(channel.to_proto());
4267 }
4268
4269 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4270 update.latest_channel_message_ids = channels.latest_channel_messages;
4271
4272 for (channel_id, participants) in channels.channel_participants {
4273 update
4274 .channel_participants
4275 .push(proto::ChannelParticipants {
4276 channel_id: channel_id.to_proto(),
4277 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4278 });
4279 }
4280
4281 for channel in channels.invited_channels {
4282 update.channel_invitations.push(channel.to_proto());
4283 }
4284
4285 update
4286}
4287
4288fn build_initial_contacts_update(
4289 contacts: Vec<db::Contact>,
4290 pool: &ConnectionPool,
4291) -> proto::UpdateContacts {
4292 let mut update = proto::UpdateContacts::default();
4293
4294 for contact in contacts {
4295 match contact {
4296 db::Contact::Accepted { user_id, busy } => {
4297 update.contacts.push(contact_for_user(user_id, busy, pool));
4298 }
4299 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4300 db::Contact::Incoming { user_id } => {
4301 update
4302 .incoming_requests
4303 .push(proto::IncomingContactRequest {
4304 requester_id: user_id.to_proto(),
4305 })
4306 }
4307 }
4308 }
4309
4310 update
4311}
4312
4313fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4314 proto::Contact {
4315 user_id: user_id.to_proto(),
4316 online: pool.is_user_online(user_id),
4317 busy,
4318 }
4319}
4320
4321fn room_updated(room: &proto::Room, peer: &Peer) {
4322 broadcast(
4323 None,
4324 room.participants
4325 .iter()
4326 .filter_map(|participant| Some(participant.peer_id?.into())),
4327 |peer_id| {
4328 peer.send(
4329 peer_id,
4330 proto::RoomUpdated {
4331 room: Some(room.clone()),
4332 },
4333 )
4334 },
4335 );
4336}
4337
4338fn channel_updated(
4339 channel: &db::channel::Model,
4340 room: &proto::Room,
4341 peer: &Peer,
4342 pool: &ConnectionPool,
4343) {
4344 let participants = room
4345 .participants
4346 .iter()
4347 .map(|p| p.user_id)
4348 .collect::<Vec<_>>();
4349
4350 broadcast(
4351 None,
4352 pool.channel_connection_ids(channel.root_id())
4353 .filter_map(|(channel_id, role)| {
4354 role.can_see_channel(channel.visibility)
4355 .then_some(channel_id)
4356 }),
4357 |peer_id| {
4358 peer.send(
4359 peer_id,
4360 proto::UpdateChannels {
4361 channel_participants: vec![proto::ChannelParticipants {
4362 channel_id: channel.id.to_proto(),
4363 participant_user_ids: participants.clone(),
4364 }],
4365 ..Default::default()
4366 },
4367 )
4368 },
4369 );
4370}
4371
4372async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4373 let db = session.db().await;
4374
4375 let contacts = db.get_contacts(user_id).await?;
4376 let busy = db.is_user_busy(user_id).await?;
4377
4378 let pool = session.connection_pool().await;
4379 let updated_contact = contact_for_user(user_id, busy, &pool);
4380 for contact in contacts {
4381 if let db::Contact::Accepted {
4382 user_id: contact_user_id,
4383 ..
4384 } = contact
4385 {
4386 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4387 session
4388 .peer
4389 .send(
4390 contact_conn_id,
4391 proto::UpdateContacts {
4392 contacts: vec![updated_contact.clone()],
4393 remove_contacts: Default::default(),
4394 incoming_requests: Default::default(),
4395 remove_incoming_requests: Default::default(),
4396 outgoing_requests: Default::default(),
4397 remove_outgoing_requests: Default::default(),
4398 },
4399 )
4400 .trace_err();
4401 }
4402 }
4403 }
4404 Ok(())
4405}
4406
4407async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4408 let mut contacts_to_update = HashSet::default();
4409
4410 let room_id;
4411 let canceled_calls_to_user_ids;
4412 let livekit_room;
4413 let delete_livekit_room;
4414 let room;
4415 let channel;
4416
4417 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4418 contacts_to_update.insert(session.user_id());
4419
4420 for project in left_room.left_projects.values() {
4421 project_left(project, session);
4422 }
4423
4424 room_id = RoomId::from_proto(left_room.room.id);
4425 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4426 livekit_room = mem::take(&mut left_room.room.livekit_room);
4427 delete_livekit_room = left_room.deleted;
4428 room = mem::take(&mut left_room.room);
4429 channel = mem::take(&mut left_room.channel);
4430
4431 room_updated(&room, &session.peer);
4432 } else {
4433 return Ok(());
4434 }
4435
4436 if let Some(channel) = channel {
4437 channel_updated(
4438 &channel,
4439 &room,
4440 &session.peer,
4441 &*session.connection_pool().await,
4442 );
4443 }
4444
4445 {
4446 let pool = session.connection_pool().await;
4447 for canceled_user_id in canceled_calls_to_user_ids {
4448 for connection_id in pool.user_connection_ids(canceled_user_id) {
4449 session
4450 .peer
4451 .send(
4452 connection_id,
4453 proto::CallCanceled {
4454 room_id: room_id.to_proto(),
4455 },
4456 )
4457 .trace_err();
4458 }
4459 contacts_to_update.insert(canceled_user_id);
4460 }
4461 }
4462
4463 for contact_user_id in contacts_to_update {
4464 update_user_contacts(contact_user_id, session).await?;
4465 }
4466
4467 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4468 live_kit
4469 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4470 .await
4471 .trace_err();
4472
4473 if delete_livekit_room {
4474 live_kit.delete_room(livekit_room).await.trace_err();
4475 }
4476 }
4477
4478 Ok(())
4479}
4480
4481async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4482 let left_channel_buffers = session
4483 .db()
4484 .await
4485 .leave_channel_buffers(session.connection_id)
4486 .await?;
4487
4488 for left_buffer in left_channel_buffers {
4489 channel_buffer_updated(
4490 session.connection_id,
4491 left_buffer.connections,
4492 &proto::UpdateChannelBufferCollaborators {
4493 channel_id: left_buffer.channel_id.to_proto(),
4494 collaborators: left_buffer.collaborators,
4495 },
4496 &session.peer,
4497 );
4498 }
4499
4500 Ok(())
4501}
4502
4503fn project_left(project: &db::LeftProject, session: &Session) {
4504 for connection_id in &project.connection_ids {
4505 if project.should_unshare {
4506 session
4507 .peer
4508 .send(
4509 *connection_id,
4510 proto::UnshareProject {
4511 project_id: project.id.to_proto(),
4512 },
4513 )
4514 .trace_err();
4515 } else {
4516 session
4517 .peer
4518 .send(
4519 *connection_id,
4520 proto::RemoveProjectCollaborator {
4521 project_id: project.id.to_proto(),
4522 peer_id: Some(session.connection_id.into()),
4523 },
4524 )
4525 .trace_err();
4526 }
4527 }
4528}
4529
/// Extension for fallible values whose failure should be reported, not
/// propagated.
pub trait ResultExt {
    /// The value produced when the operation succeeded.
    type Ok;

    /// Returns the success value as `Some`, or `None` on failure.
    ///
    /// Implementations are expected to report the failure (for example, by
    /// logging it) before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4535
4536impl<T, E> ResultExt for Result<T, E>
4537where
4538 E: std::fmt::Debug,
4539{
4540 type Ok = T;
4541
4542 #[track_caller]
4543 fn trace_err(self) -> Option<T> {
4544 match self {
4545 Ok(value) => Some(value),
4546 Err(error) => {
4547 tracing::error!("{:?}", error);
4548 None
4549 }
4550 }
4551 }
4552}