1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// Grace period associated with client reconnection. Its consumers live outside this
/// excerpt — check them before relying on finer semantics.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay `Server::start` waits before cleaning up resources left behind by previous server
/// instances.
///
/// Kubernetes gives terminated pods 10 seconds to shut down gracefully; once they are gone,
/// their old resources can safely be cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size, presumably for channel message history requests (used outside this excerpt).
const MESSAGE_COUNT_PER_PAGE: usize = 100;

/// Length limit, presumably for channel chat messages — verify at the use site.
const MAX_MESSAGE_LEN: usize = 1024;

/// Page size, presumably for notification requests (used outside this excerpt).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
92
93type MessageHandler =
94 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
95
96struct Response<R> {
97 peer: Arc<Peer>,
98 receipt: Receipt<R>,
99 responded: Arc<AtomicBool>,
100}
101
102impl<R: RequestMessage> Response<R> {
103 fn send(self, payload: R::Response) -> Result<()> {
104 self.responded.store(true, SeqCst);
105 self.peer.respond(self.receipt, payload)?;
106 Ok(())
107 }
108}
109
110#[derive(Clone, Debug)]
111pub enum Principal {
112 User(User),
113 Impersonated { user: User, admin: User },
114}
115
116impl Principal {
117 fn user(&self) -> &User {
118 match self {
119 Principal::User(user) => user,
120 Principal::Impersonated { user, .. } => user,
121 }
122 }
123
124 fn update_span(&self, span: &tracing::Span) {
125 match &self {
126 Principal::User(user) => {
127 span.record("user_id", user.id.0);
128 span.record("login", &user.github_login);
129 }
130 Principal::Impersonated { user, admin } => {
131 span.record("user_id", user.id.0);
132 span.record("login", &user.github_login);
133 span.record("impersonator", &admin.github_login);
134 }
135 }
136 }
137}
138
/// Per-connection state. `handle_connection` builds one per connection and hands a clone of
/// it to every message handler it dispatches.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Each session wraps the shared `Database` in its own async mutex (created in
    /// `handle_connection`), so database access through `Session::db` is serialized per
    /// connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool (the same `Arc` the `Server` holds); lock it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // Not read in this excerpt; the allow presumably silences the dead-code warning.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
154
155impl Session {
156 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 let guard = self.db.lock().await;
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 guard
163 }
164
165 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
166 #[cfg(test)]
167 tokio::task::yield_now().await;
168 let guard = self.connection_pool.lock();
169 ConnectionPoolGuard {
170 guard,
171 _not_send: PhantomData,
172 }
173 }
174
175 fn is_staff(&self) -> bool {
176 match &self.principal {
177 Principal::User(user) => user.admin,
178 Principal::Impersonated { .. } => true,
179 }
180 }
181
182 fn user_id(&self) -> UserId {
183 match &self.principal {
184 Principal::User(user) => user.id,
185 Principal::Impersonated { user, .. } => user.id,
186 }
187 }
188
189 pub fn email(&self) -> Option<String> {
190 match &self.principal {
191 Principal::User(user) => user.email_address.clone(),
192 Principal::Impersonated { user, .. } => user.email_address.clone(),
193 }
194 }
195}
196
197impl Debug for Session {
198 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
199 let mut result = f.debug_struct("Session");
200 match &self.principal {
201 Principal::User(user) => {
202 result.field("user", &user.github_login);
203 }
204 Principal::Impersonated { user, admin } => {
205 result.field("user", &user.github_login);
206 result.field("impersonator", &admin.github_login);
207 }
208 }
209 result.field("connection_id", &self.connection_id).finish()
210 }
211}
212
213struct DbHandle(Arc<Database>);
214
215impl Deref for DbHandle {
216 type Target = Database;
217
218 fn deref(&self) -> &Self::Target {
219 self.0.as_ref()
220 }
221}
222
/// The collab RPC server. It routes messages from connected clients to the handlers
/// registered in `Server::new`.
pub struct Server {
    /// This instance's id. It sits behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; the same `Arc` is cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Every connection task subscribes to it and stops when
    /// it changes.
    teardown: watch::Sender<bool>,
}
231
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` regardless of the underlying
/// mutex guard. A future that holds it across an `.await` therefore cannot be sent to
/// another thread.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
236
/// Serializable point-in-time view of the server: the peer plus the connection pool. The
/// pool stays locked for as long as the snapshot (and its guard) is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
243
244pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
245where
246 S: Serializer,
247 T: Deref<Target = U>,
248 U: Serialize,
249{
250 Serialize::serialize(value.deref(), serializer)
251}
252
253impl Server {
254 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
255 let mut server = Self {
256 id: parking_lot::Mutex::new(id),
257 peer: Peer::new(id.0 as u32),
258 app_state: app_state.clone(),
259 connection_pool: Default::default(),
260 handlers: Default::default(),
261 teardown: watch::channel(false).0,
262 };
263
264 server
265 .add_request_handler(ping)
266 .add_request_handler(create_room)
267 .add_request_handler(join_room)
268 .add_request_handler(rejoin_room)
269 .add_request_handler(leave_room)
270 .add_request_handler(set_room_participant_role)
271 .add_request_handler(call)
272 .add_request_handler(cancel_call)
273 .add_message_handler(decline_call)
274 .add_request_handler(update_participant_location)
275 .add_request_handler(share_project)
276 .add_message_handler(unshare_project)
277 .add_request_handler(join_project)
278 .add_message_handler(leave_project)
279 .add_request_handler(update_project)
280 .add_request_handler(update_worktree)
281 .add_request_handler(update_repository)
282 .add_request_handler(remove_repository)
283 .add_message_handler(start_language_server)
284 .add_message_handler(update_language_server)
285 .add_message_handler(update_diagnostic_summary)
286 .add_message_handler(update_worktree_settings)
287 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
288 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
289 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
291 .add_request_handler(forward_find_search_candidates_request)
292 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
293 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
294 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
295 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
296 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
297 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
298 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
299 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
300 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
302 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
303 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
304 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
305 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
306 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
307 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
308 .add_request_handler(
309 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
310 )
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
312 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
313 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
314 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
315 .add_request_handler(
316 forward_read_only_project_request::<proto::LanguageServerIdForName>,
317 )
318 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
323 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
324 .add_request_handler(
325 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
328 .add_request_handler(
329 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
330 )
331 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
332 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
333 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
334 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
335 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
336 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
337 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
338 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
339 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
340 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
341 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
342 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
343 .add_request_handler(
344 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
345 )
346 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
347 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
348 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
349 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
350 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
351 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
352 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
353 .add_message_handler(create_buffer_for_peer)
354 .add_request_handler(update_buffer)
355 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
357 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
358 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
359 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
360 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
361 .add_message_handler(
362 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
363 )
364 .add_request_handler(get_users)
365 .add_request_handler(fuzzy_search_users)
366 .add_request_handler(request_contact)
367 .add_request_handler(remove_contact)
368 .add_request_handler(respond_to_contact_request)
369 .add_message_handler(subscribe_to_channels)
370 .add_request_handler(create_channel)
371 .add_request_handler(delete_channel)
372 .add_request_handler(invite_channel_member)
373 .add_request_handler(remove_channel_member)
374 .add_request_handler(set_channel_member_role)
375 .add_request_handler(set_channel_visibility)
376 .add_request_handler(rename_channel)
377 .add_request_handler(join_channel_buffer)
378 .add_request_handler(leave_channel_buffer)
379 .add_message_handler(update_channel_buffer)
380 .add_request_handler(rejoin_channel_buffers)
381 .add_request_handler(get_channel_members)
382 .add_request_handler(respond_to_channel_invite)
383 .add_request_handler(join_channel)
384 .add_request_handler(join_channel_chat)
385 .add_message_handler(leave_channel_chat)
386 .add_request_handler(send_channel_message)
387 .add_request_handler(remove_channel_message)
388 .add_request_handler(update_channel_message)
389 .add_request_handler(get_channel_messages)
390 .add_request_handler(get_channel_messages_by_id)
391 .add_request_handler(get_notifications)
392 .add_request_handler(mark_notification_as_read)
393 .add_request_handler(move_channel)
394 .add_request_handler(reorder_channel)
395 .add_request_handler(follow)
396 .add_message_handler(unfollow)
397 .add_message_handler(update_followers)
398 .add_request_handler(get_private_user_info)
399 .add_request_handler(get_llm_api_token)
400 .add_request_handler(accept_terms_of_service)
401 .add_message_handler(acknowledge_channel_message)
402 .add_message_handler(acknowledge_buffer_version)
403 .add_request_handler(get_supermaven_api_key)
404 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
405 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
406 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
407 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
408 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
409 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
410 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
411 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
412 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
413 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
414 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
415 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
416 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
417 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
418 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
419 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
420 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
421 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
422 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
423 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
424 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
425 .add_message_handler(update_context);
426
427 Arc::new(server)
428 }
429
    /// Schedules cleanup of resources left behind by previous server instances, then returns
    /// immediately with `Ok(())`.
    ///
    /// The cleanup runs on a detached task once `CLEANUP_TIMEOUT` has elapsed. It first
    /// notifies affected clients: channel buffer collaborators, callees whose calls were
    /// canceled, and the contacts of affected users. It then deletes LiveKit rooms that ended
    /// up empty, and finally purges stale chat participants, old worktree entries and stale
    /// servers. Every database call is best-effort: errors are logged through `trace_err` and
    /// the task carries on.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Delay before the cleanup task does any work; see `CLEANUP_TIMEOUT` for the rationale.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                // Rooms and channel buffers still tied to stale servers (presumably servers
                // in this environment other than `server_id` — verify in the db layer).
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the updated
                    // collaborator list to the buffer's remaining connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants, broadcast the refreshed room (and its
                        // channel, if any), and collect the follow-up work handled below.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that their call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, build a fresh contact entry (from their busy
                        // status and the pool) and push it to all connections of their
                        // accepted contacts. Users whose lookups fail are skipped.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room has no
                        // participants left (and a LiveKit client is configured).
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Same call as at the start of the task, repeated after the room cleanup.
                // NOTE(review): presumably this catches participants that went stale in the
                // meantime — confirm the repeat is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
600
601 pub fn teardown(&self) {
602 self.peer.teardown();
603 self.connection_pool.lock().reset();
604 let _ = self.teardown.send(true);
605 }
606
607 #[cfg(test)]
608 pub fn reset(&self, id: ServerId) {
609 self.teardown();
610 *self.id.lock() = id;
611 self.peer.reset(id.0 as u32);
612 let _ = self.teardown.send(false);
613 }
614
615 #[cfg(test)]
616 pub fn id(&self) -> ServerId {
617 *self.id.lock()
618 }
619
620 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
621 where
622 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
623 Fut: 'static + Send + Future<Output = Result<()>>,
624 M: EnvelopedMessage,
625 {
626 let prev_handler = self.handlers.insert(
627 TypeId::of::<M>(),
628 Box::new(move |envelope, session| {
629 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
630 let received_at = envelope.received_at;
631 tracing::info!("message received");
632 let start_time = Instant::now();
633 let future = (handler)(*envelope, session);
634 async move {
635 let result = future.await;
636 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
637 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
638 let queue_duration_ms = total_duration_ms - processing_duration_ms;
639 let payload_type = M::NAME;
640
641 match result {
642 Err(error) => {
643 tracing::error!(
644 ?error,
645 total_duration_ms,
646 processing_duration_ms,
647 queue_duration_ms,
648 payload_type,
649 "error handling message"
650 )
651 }
652 Ok(()) => tracing::info!(
653 total_duration_ms,
654 processing_duration_ms,
655 queue_duration_ms,
656 "finished handling message"
657 ),
658 }
659 }
660 .boxed()
661 }),
662 );
663 if prev_handler.is_some() {
664 panic!("registered a handler for the same message twice");
665 }
666 self
667 }
668
669 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
670 where
671 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
672 Fut: 'static + Send + Future<Output = Result<()>>,
673 M: EnvelopedMessage,
674 {
675 self.add_handler(move |envelope, session| handler(envelope.payload, session));
676 self
677 }
678
679 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
680 where
681 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
682 Fut: Send + Future<Output = Result<()>>,
683 M: RequestMessage,
684 {
685 let handler = Arc::new(handler);
686 self.add_handler(move |envelope, session| {
687 let receipt = envelope.receipt();
688 let handler = handler.clone();
689 async move {
690 let peer = session.peer.clone();
691 let responded = Arc::new(AtomicBool::default());
692 let response = Response {
693 peer: peer.clone(),
694 responded: responded.clone(),
695 receipt,
696 };
697 match (handler)(envelope.payload, response, session).await {
698 Ok(()) => {
699 if responded.load(std::sync::atomic::Ordering::SeqCst) {
700 Ok(())
701 } else {
702 Err(anyhow!("handler did not send a response"))?
703 }
704 }
705 Err(error) => {
706 let proto_err = match &error {
707 Error::Internal(err) => err.to_proto(),
708 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
709 };
710 peer.respond_with_error(receipt, proto_err)?;
711 Err(error)
712 }
713 }
714 }
715 })
716 }
717
    /// Returns a future that drives a single client connection until it ends.
    ///
    /// The future registers `connection` with the peer, builds the connection's `Session`,
    /// and sends the initial client update. It then dispatches incoming messages to the
    /// registered handlers until the connection's I/O finishes, the client closes the
    /// connection, or the server is torn down. After the loop exits normally it calls
    /// `connection_lost` to sign the user out. It returns without calling it when teardown is
    /// observed, or when setup (HTTP client creation or the initial client update) fails.
    ///
    /// `send_connection_id`, when provided, receives the new `ConnectionId` right after the
    /// `Hello` message is sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in below once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once the server has been torn down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after `send_initial_client_update`)
            // skips `connection_lost`, even though the connection is already registered with
            // the peer (and possibly the pool). Confirm that nothing leaks on these paths.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection are in flight at once: a permit is taken
            // before the next message is read and released when that message's handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown first, then connection I/O, then progress on in-flight
                // foreground handlers, and only then a new incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Driving foreground handlers forward; a completion needs no extra work.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
865
866 async fn send_initial_client_update(
867 &self,
868 connection_id: ConnectionId,
869 zed_version: ZedVersion,
870 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
871 session: &Session,
872 ) -> Result<()> {
873 self.peer.send(
874 connection_id,
875 proto::Hello {
876 peer_id: Some(connection_id.into()),
877 },
878 )?;
879 tracing::info!("sent hello message");
880 if let Some(send_connection_id) = send_connection_id.take() {
881 let _ = send_connection_id.send(connection_id);
882 }
883
884 match &session.principal {
885 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
886 if !user.connected_once {
887 self.peer.send(connection_id, proto::ShowContacts {})?;
888 self.app_state
889 .db
890 .set_user_connected_once(user.id, true)
891 .await?;
892 }
893
894 update_user_plan(session).await?;
895
896 let contacts = self.app_state.db.get_contacts(user.id).await?;
897
898 {
899 let mut pool = self.connection_pool.lock();
900 pool.add_connection(connection_id, user.id, user.admin, zed_version);
901 self.peer.send(
902 connection_id,
903 build_initial_contacts_update(contacts, &pool),
904 )?;
905 }
906
907 if should_auto_subscribe_to_channels(zed_version) {
908 subscribe_user_to_channels(user.id, session).await?;
909 }
910
911 if let Some(incoming_call) =
912 self.app_state.db.incoming_call_for_user(user.id).await?
913 {
914 self.peer.send(connection_id, incoming_call)?;
915 }
916
917 update_user_contacts(user.id, session).await?;
918 }
919 }
920
921 Ok(())
922 }
923
924 pub async fn invite_code_redeemed(
925 self: &Arc<Self>,
926 inviter_id: UserId,
927 invitee_id: UserId,
928 ) -> Result<()> {
929 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
930 if let Some(code) = &user.invite_code {
931 let pool = self.connection_pool.lock();
932 let invitee_contact = contact_for_user(invitee_id, false, &pool);
933 for connection_id in pool.user_connection_ids(inviter_id) {
934 self.peer.send(
935 connection_id,
936 proto::UpdateContacts {
937 contacts: vec![invitee_contact.clone()],
938 ..Default::default()
939 },
940 )?;
941 self.peer.send(
942 connection_id,
943 proto::UpdateInviteInfo {
944 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
945 count: user.invite_count as u32,
946 },
947 )?;
948 }
949 }
950 }
951 Ok(())
952 }
953
954 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
955 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
956 if let Some(invite_code) = &user.invite_code {
957 let pool = self.connection_pool.lock();
958 for connection_id in pool.user_connection_ids(user_id) {
959 self.peer.send(
960 connection_id,
961 proto::UpdateInviteInfo {
962 url: format!(
963 "{}{}",
964 self.app_state.config.invite_link_prefix, invite_code
965 ),
966 count: user.invite_count as u32,
967 },
968 )?;
969 }
970 }
971 }
972 Ok(())
973 }
974
975 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
976 let user = self
977 .app_state
978 .db
979 .get_user_by_id(user_id)
980 .await?
981 .context("user not found")?;
982
983 let update_user_plan = make_update_user_plan_message(
984 &user,
985 user.admin,
986 &self.app_state.db,
987 self.app_state.llm_db.clone(),
988 )
989 .await?;
990
991 let pool = self.connection_pool.lock();
992 for connection_id in pool.user_connection_ids(user_id) {
993 self.peer
994 .send(connection_id, update_user_plan.clone())
995 .trace_err();
996 }
997
998 Ok(())
999 }
1000
1001 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1002 let pool = self.connection_pool.lock();
1003 for connection_id in pool.user_connection_ids(user_id) {
1004 self.peer
1005 .send(connection_id, proto::RefreshLlmToken {})
1006 .trace_err();
1007 }
1008 }
1009
1010 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1011 ServerSnapshot {
1012 connection_pool: ConnectionPoolGuard {
1013 guard: self.connection_pool.lock(),
1014 _not_send: PhantomData,
1015 },
1016 peer: &self.peer,
1017 }
1018 }
1019}
1020
1021impl Deref for ConnectionPoolGuard<'_> {
1022 type Target = ConnectionPool;
1023
1024 fn deref(&self) -> &Self::Target {
1025 &self.guard
1026 }
1027}
1028
1029impl DerefMut for ConnectionPoolGuard<'_> {
1030 fn deref_mut(&mut self) -> &mut Self::Target {
1031 &mut self.guard
1032 }
1033}
1034
1035impl Drop for ConnectionPoolGuard<'_> {
1036 fn drop(&mut self) {
1037 #[cfg(test)]
1038 self.check_invariants();
1039 }
1040}
1041
1042fn broadcast<F>(
1043 sender_id: Option<ConnectionId>,
1044 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1045 mut f: F,
1046) where
1047 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1048{
1049 for receiver_id in receiver_ids {
1050 if Some(receiver_id) != sender_id {
1051 if let Err(error) = f(receiver_id) {
1052 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1053 }
1054 }
1055 }
1056}
1057
1058pub struct ProtocolVersion(u32);
1059
1060impl Header for ProtocolVersion {
1061 fn name() -> &'static HeaderName {
1062 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1063 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1064 }
1065
1066 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1067 where
1068 Self: Sized,
1069 I: Iterator<Item = &'i axum::http::HeaderValue>,
1070 {
1071 let version = values
1072 .next()
1073 .ok_or_else(axum::headers::Error::invalid)?
1074 .to_str()
1075 .map_err(|_| axum::headers::Error::invalid())?
1076 .parse()
1077 .map_err(|_| axum::headers::Error::invalid())?;
1078 Ok(Self(version))
1079 }
1080
1081 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1082 values.extend([self.0.to_string().parse().unwrap()]);
1083 }
1084}
1085
1086pub struct AppVersionHeader(SemanticVersion);
1087impl Header for AppVersionHeader {
1088 fn name() -> &'static HeaderName {
1089 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1090 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1091 }
1092
1093 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1094 where
1095 Self: Sized,
1096 I: Iterator<Item = &'i axum::http::HeaderValue>,
1097 {
1098 let version = values
1099 .next()
1100 .ok_or_else(axum::headers::Error::invalid)?
1101 .to_str()
1102 .map_err(|_| axum::headers::Error::invalid())?
1103 .parse()
1104 .map_err(|_| axum::headers::Error::invalid())?;
1105 Ok(Self(version))
1106 }
1107
1108 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1109 values.extend([self.0.to_string().parse().unwrap()]);
1110 }
1111}
1112
/// Builds the router for the collab RPC endpoints.
///
/// Layer order matters here:
/// - The `ServiceBuilder` layer (the `AppState` extension plus
///   `auth::validate_header`) is applied right after `/rpc` is registered, so
///   it wraps `/rpc` only.
/// - `/metrics` is registered afterwards, so that middleware does not wrap it.
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to routes registered above this call.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1124
1125pub async fn handle_websocket_request(
1126 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1127 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1128 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1129 Extension(server): Extension<Arc<Server>>,
1130 Extension(principal): Extension<Principal>,
1131 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1132 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1133 ws: WebSocketUpgrade,
1134) -> axum::response::Response {
1135 if protocol_version != rpc::PROTOCOL_VERSION {
1136 return (
1137 StatusCode::UPGRADE_REQUIRED,
1138 "client must be upgraded".to_string(),
1139 )
1140 .into_response();
1141 }
1142
1143 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1144 return (
1145 StatusCode::UPGRADE_REQUIRED,
1146 "no version header found".to_string(),
1147 )
1148 .into_response();
1149 };
1150
1151 if !version.can_collaborate() {
1152 return (
1153 StatusCode::UPGRADE_REQUIRED,
1154 "client must be upgraded".to_string(),
1155 )
1156 .into_response();
1157 }
1158
1159 let socket_address = socket_address.to_string();
1160 ws.on_upgrade(move |socket| {
1161 let socket = socket
1162 .map_ok(to_tungstenite_message)
1163 .err_into()
1164 .with(|message| async move { to_axum_message(message) });
1165 let connection = Connection::new(Box::pin(socket));
1166 async move {
1167 server
1168 .handle_connection(
1169 connection,
1170 socket_address,
1171 principal,
1172 version,
1173 country_code_header.map(|header| header.to_string()),
1174 system_id_header.map(|header| header.to_string()),
1175 None,
1176 Executor::Production,
1177 )
1178 .await;
1179 }
1180 })
1181}
1182
1183pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1184 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1185 let connections_metric = CONNECTIONS_METRIC
1186 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1187
1188 let connections = server
1189 .connection_pool
1190 .lock()
1191 .connections()
1192 .filter(|connection| !connection.admin)
1193 .count();
1194 connections_metric.set(connections as _);
1195
1196 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1197 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1198 register_int_gauge!(
1199 "shared_projects",
1200 "number of open projects with one or more guests"
1201 )
1202 .unwrap()
1203 });
1204
1205 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1206 shared_projects_metric.set(shared_projects as _);
1207
1208 let encoder = prometheus::TextEncoder::new();
1209 let metric_families = prometheus::gather();
1210 let encoded_metrics = encoder
1211 .encode_to_string(&metric_families)
1212 .map_err(|err| anyhow!("{err}"))?;
1213 Ok(encoded_metrics)
1214}
1215
/// Tears down a connection whose message loop has ended.
///
/// Steps performed right away:
/// 1. Disconnects the peer.
/// 2. Removes the connection from the pool.
/// 3. Records the loss in the database; errors from this call are only logged.
///
/// It then races `RECONNECT_TIMEOUT` against a change on `teardown`. If the
/// timeout fires first, it does the following:
/// - leaves the session's room and channel buffers;
/// - declines any call for the user when the user has no other online connection;
/// - updates the user's contacts.
///
/// If `teardown` changes first, that deferred cleanup is skipped.
///
/// # Errors
/// Fails if the connection cannot be removed from the pool or if updating
/// contacts fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch is checked before the teardown signal on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1262
1263/// Acknowledges a ping from a client, used to keep the connection alive.
1264async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1265 response.send(proto::Ack {})?;
1266 Ok(())
1267}
1268
1269/// Creates a new room for calling (outside of channels)
1270async fn create_room(
1271 _request: proto::CreateRoom,
1272 response: Response<proto::CreateRoom>,
1273 session: Session,
1274) -> Result<()> {
1275 let livekit_room = nanoid::nanoid!(30);
1276
1277 let live_kit_connection_info = util::maybe!(async {
1278 let live_kit = session.app_state.livekit_client.as_ref();
1279 let live_kit = live_kit?;
1280 let user_id = session.user_id().to_string();
1281
1282 let token = live_kit
1283 .room_token(&livekit_room, &user_id.to_string())
1284 .trace_err()?;
1285
1286 Some(proto::LiveKitConnectionInfo {
1287 server_url: live_kit.url().into(),
1288 token,
1289 can_publish: true,
1290 })
1291 })
1292 .await;
1293
1294 let room = session
1295 .db()
1296 .await
1297 .create_room(session.user_id(), session.connection_id, &livekit_room)
1298 .await?;
1299
1300 response.send(proto::CreateRoomResponse {
1301 room: Some(room.clone()),
1302 live_kit_connection_info,
1303 })?;
1304
1305 update_user_contacts(session.user_id(), &session).await?;
1306 Ok(())
1307}
1308
1309/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1310async fn join_room(
1311 request: proto::JoinRoom,
1312 response: Response<proto::JoinRoom>,
1313 session: Session,
1314) -> Result<()> {
1315 let room_id = RoomId::from_proto(request.id);
1316
1317 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1318
1319 if let Some(channel_id) = channel_id {
1320 return join_channel_internal(channel_id, Box::new(response), session).await;
1321 }
1322
1323 let joined_room = {
1324 let room = session
1325 .db()
1326 .await
1327 .join_room(room_id, session.user_id(), session.connection_id)
1328 .await?;
1329 room_updated(&room.room, &session.peer);
1330 room.into_inner()
1331 };
1332
1333 for connection_id in session
1334 .connection_pool()
1335 .await
1336 .user_connection_ids(session.user_id())
1337 {
1338 session
1339 .peer
1340 .send(
1341 connection_id,
1342 proto::CallCanceled {
1343 room_id: room_id.to_proto(),
1344 },
1345 )
1346 .trace_err();
1347 }
1348
1349 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1350 {
1351 live_kit
1352 .room_token(
1353 &joined_room.room.livekit_room,
1354 &session.user_id().to_string(),
1355 )
1356 .trace_err()
1357 .map(|token| proto::LiveKitConnectionInfo {
1358 server_url: live_kit.url().into(),
1359 token,
1360 can_publish: true,
1361 })
1362 } else {
1363 None
1364 };
1365
1366 response.send(proto::JoinRoomResponse {
1367 room: Some(joined_room.room),
1368 channel_id: None,
1369 live_kit_connection_info,
1370 })?;
1371
1372 update_user_contacts(session.user_id(), &session).await?;
1373 Ok(())
1374}
1375
/// Reconnects the session to a room after connection errors.
///
/// Steps, in order:
/// 1. Restores membership from `request`.
/// 2. Replies with the room and the reshared and rejoined projects.
/// 3. Broadcasts the room update.
/// 4. For each reshared project, sends `UpdateProjectCollaborator` (carrying the
///    session's new peer id) to that project's collaborators, and forwards
///    `UpdateProject` with the worktrees to the collaborators other than this
///    connection.
/// 5. Streams the rejoined projects to this connection via `notify_rejoined_projects`.
/// 6. Updates the channel if the room belongs to one.
/// 7. Refreshes the user's contacts.
///
/// # Errors
/// Fails on database errors, if the response cannot be sent, or if
/// `notify_rejoined_projects` fails.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so the values outlive `rejoined_room`, which is
    // consumed by `into_inner` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: presumably the ones this connection hosts again
        // (TODO confirm). Point collaborators at the new peer id, then resend the
        // project's worktrees.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // `broadcast` skips the rejoining connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1467
/// Sends the state of each rejoined project to the rejoining connection.
///
/// First, every collaborator of every project receives
/// `UpdateProjectCollaborator` with the session's new peer id. Then, for each
/// project, this connection receives:
/// - the worktree entries, split into chunks by `split_worktree_update`;
/// - diagnostic summaries;
/// - settings files;
/// - updated repositories;
/// - removed repository ids.
///
/// The worktree and repository collections are moved out of
/// `rejoined_projects` with `mem::take`, leaving them empty.
///
/// # Errors
/// Fails on the first message to this connection that cannot be sent. Failures
/// when notifying collaborators are only logged.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository updates, split into chunks.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        // Tell the client which repositories no longer exist.
        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1552
1553/// leave room disconnects from the room.
1554async fn leave_room(
1555 _: proto::LeaveRoom,
1556 response: Response<proto::LeaveRoom>,
1557 session: Session,
1558) -> Result<()> {
1559 leave_room_for_session(&session, session.connection_id).await?;
1560 response.send(proto::Ack {})?;
1561 Ok(())
1562}
1563
1564/// Updates the permissions of someone else in the room.
1565async fn set_room_participant_role(
1566 request: proto::SetRoomParticipantRole,
1567 response: Response<proto::SetRoomParticipantRole>,
1568 session: Session,
1569) -> Result<()> {
1570 let user_id = UserId::from_proto(request.user_id);
1571 let role = ChannelRole::from(request.role());
1572
1573 let (livekit_room, can_publish) = {
1574 let room = session
1575 .db()
1576 .await
1577 .set_room_participant_role(
1578 session.user_id(),
1579 RoomId::from_proto(request.room_id),
1580 user_id,
1581 role,
1582 )
1583 .await?;
1584
1585 let livekit_room = room.livekit_room.clone();
1586 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1587 room_updated(&room, &session.peer);
1588 (livekit_room, can_publish)
1589 };
1590
1591 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1592 live_kit
1593 .update_participant(
1594 livekit_room.clone(),
1595 request.user_id.to_string(),
1596 livekit_api::proto::ParticipantPermission {
1597 can_subscribe: true,
1598 can_publish,
1599 can_publish_data: can_publish,
1600 hidden: false,
1601 recorder: false,
1602 },
1603 )
1604 .await
1605 .trace_err();
1606 }
1607
1608 response.send(proto::Ack {})?;
1609 Ok(())
1610}
1611
1612/// Call someone else into the current room
1613async fn call(
1614 request: proto::Call,
1615 response: Response<proto::Call>,
1616 session: Session,
1617) -> Result<()> {
1618 let room_id = RoomId::from_proto(request.room_id);
1619 let calling_user_id = session.user_id();
1620 let calling_connection_id = session.connection_id;
1621 let called_user_id = UserId::from_proto(request.called_user_id);
1622 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1623 if !session
1624 .db()
1625 .await
1626 .has_contact(calling_user_id, called_user_id)
1627 .await?
1628 {
1629 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1630 }
1631
1632 let incoming_call = {
1633 let (room, incoming_call) = &mut *session
1634 .db()
1635 .await
1636 .call(
1637 room_id,
1638 calling_user_id,
1639 calling_connection_id,
1640 called_user_id,
1641 initial_project_id,
1642 )
1643 .await?;
1644 room_updated(room, &session.peer);
1645 mem::take(incoming_call)
1646 };
1647 update_user_contacts(called_user_id, &session).await?;
1648
1649 let mut calls = session
1650 .connection_pool()
1651 .await
1652 .user_connection_ids(called_user_id)
1653 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1654 .collect::<FuturesUnordered<_>>();
1655
1656 while let Some(call_response) = calls.next().await {
1657 match call_response.as_ref() {
1658 Ok(_) => {
1659 response.send(proto::Ack {})?;
1660 return Ok(());
1661 }
1662 Err(_) => {
1663 call_response.trace_err();
1664 }
1665 }
1666 }
1667
1668 {
1669 let room = session
1670 .db()
1671 .await
1672 .call_failed(room_id, called_user_id)
1673 .await?;
1674 room_updated(&room, &session.peer);
1675 }
1676 update_user_contacts(called_user_id, &session).await?;
1677
1678 Err(anyhow!("failed to ring user"))?
1679}
1680
1681/// Cancel an outgoing call.
1682async fn cancel_call(
1683 request: proto::CancelCall,
1684 response: Response<proto::CancelCall>,
1685 session: Session,
1686) -> Result<()> {
1687 let called_user_id = UserId::from_proto(request.called_user_id);
1688 let room_id = RoomId::from_proto(request.room_id);
1689 {
1690 let room = session
1691 .db()
1692 .await
1693 .cancel_call(room_id, session.connection_id, called_user_id)
1694 .await?;
1695 room_updated(&room, &session.peer);
1696 }
1697
1698 for connection_id in session
1699 .connection_pool()
1700 .await
1701 .user_connection_ids(called_user_id)
1702 {
1703 session
1704 .peer
1705 .send(
1706 connection_id,
1707 proto::CallCanceled {
1708 room_id: room_id.to_proto(),
1709 },
1710 )
1711 .trace_err();
1712 }
1713 response.send(proto::Ack {})?;
1714
1715 update_user_contacts(called_user_id, &session).await?;
1716 Ok(())
1717}
1718
1719/// Decline an incoming call.
1720async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1721 let room_id = RoomId::from_proto(message.room_id);
1722 {
1723 let room = session
1724 .db()
1725 .await
1726 .decline_call(Some(room_id), session.user_id())
1727 .await?
1728 .context("declining call")?;
1729 room_updated(&room, &session.peer);
1730 }
1731
1732 for connection_id in session
1733 .connection_pool()
1734 .await
1735 .user_connection_ids(session.user_id())
1736 {
1737 session
1738 .peer
1739 .send(
1740 connection_id,
1741 proto::CallCanceled {
1742 room_id: room_id.to_proto(),
1743 },
1744 )
1745 .trace_err();
1746 }
1747 update_user_contacts(session.user_id(), &session).await?;
1748 Ok(())
1749}
1750
1751/// Updates other participants in the room with your current location.
1752async fn update_participant_location(
1753 request: proto::UpdateParticipantLocation,
1754 response: Response<proto::UpdateParticipantLocation>,
1755 session: Session,
1756) -> Result<()> {
1757 let room_id = RoomId::from_proto(request.room_id);
1758 let location = request.location.context("invalid location")?;
1759
1760 let db = session.db().await;
1761 let room = db
1762 .update_room_participant_location(room_id, session.connection_id, location)
1763 .await?;
1764
1765 room_updated(&room, &session.peer);
1766 response.send(proto::Ack {})?;
1767 Ok(())
1768}
1769
1770/// Share a project into the room.
1771async fn share_project(
1772 request: proto::ShareProject,
1773 response: Response<proto::ShareProject>,
1774 session: Session,
1775) -> Result<()> {
1776 let (project_id, room) = &*session
1777 .db()
1778 .await
1779 .share_project(
1780 RoomId::from_proto(request.room_id),
1781 session.connection_id,
1782 &request.worktrees,
1783 request.is_ssh_project,
1784 )
1785 .await?;
1786 response.send(proto::ShareProjectResponse {
1787 project_id: project_id.to_proto(),
1788 })?;
1789 room_updated(room, &session.peer);
1790
1791 Ok(())
1792}
1793
1794/// Unshare a project from the room.
1795async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1796 let project_id = ProjectId::from_proto(message.project_id);
1797 unshare_project_internal(project_id, session.connection_id, &session).await
1798}
1799
1800async fn unshare_project_internal(
1801 project_id: ProjectId,
1802 connection_id: ConnectionId,
1803 session: &Session,
1804) -> Result<()> {
1805 let delete = {
1806 let room_guard = session
1807 .db()
1808 .await
1809 .unshare_project(project_id, connection_id)
1810 .await?;
1811
1812 let (delete, room, guest_connection_ids) = &*room_guard;
1813
1814 let message = proto::UnshareProject {
1815 project_id: project_id.to_proto(),
1816 };
1817
1818 broadcast(
1819 Some(connection_id),
1820 guest_connection_ids.iter().copied(),
1821 |conn_id| session.peer.send(conn_id, message.clone()),
1822 );
1823 if let Some(room) = room {
1824 room_updated(room, &session.peer);
1825 }
1826
1827 *delete
1828 };
1829
1830 if delete {
1831 let db = session.db().await;
1832 db.delete_project(project_id).await?;
1833 }
1834
1835 Ok(())
1836}
1837
1838/// Join someone elses shared project.
1839async fn join_project(
1840 request: proto::JoinProject,
1841 response: Response<proto::JoinProject>,
1842 session: Session,
1843) -> Result<()> {
1844 let project_id = ProjectId::from_proto(request.project_id);
1845
1846 tracing::info!(%project_id, "join project");
1847
1848 let db = session.db().await;
1849 let (project, replica_id) = &mut *db
1850 .join_project(project_id, session.connection_id, session.user_id())
1851 .await?;
1852 drop(db);
1853 tracing::info!(%project_id, "join remote project");
1854 join_project_internal(response, session, project, replica_id)
1855}
1856
/// Response sink used by [`join_project_internal`] to deliver the `JoinProjectResponse`.
///
/// Only the `Response<proto::JoinProject>` implementation is visible here.
/// Presumably the trait lets other callers reuse the join logic (TODO confirm).
trait JoinProjectInternalResponse {
    /// Sends `result` back to the joining client, consuming the sink.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Answers a `JoinProject` request directly.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`. The type is named explicitly to make
        // clear this is not a recursive call to the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1865
1866fn join_project_internal(
1867 response: impl JoinProjectInternalResponse,
1868 session: Session,
1869 project: &mut Project,
1870 replica_id: &ReplicaId,
1871) -> Result<()> {
1872 let collaborators = project
1873 .collaborators
1874 .iter()
1875 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1876 .map(|collaborator| collaborator.to_proto())
1877 .collect::<Vec<_>>();
1878 let project_id = project.id;
1879 let guest_user_id = session.user_id();
1880
1881 let worktrees = project
1882 .worktrees
1883 .iter()
1884 .map(|(id, worktree)| proto::WorktreeMetadata {
1885 id: *id,
1886 root_name: worktree.root_name.clone(),
1887 visible: worktree.visible,
1888 abs_path: worktree.abs_path.clone(),
1889 })
1890 .collect::<Vec<_>>();
1891
1892 let add_project_collaborator = proto::AddProjectCollaborator {
1893 project_id: project_id.to_proto(),
1894 collaborator: Some(proto::Collaborator {
1895 peer_id: Some(session.connection_id.into()),
1896 replica_id: replica_id.0 as u32,
1897 user_id: guest_user_id.to_proto(),
1898 is_host: false,
1899 }),
1900 };
1901
1902 for collaborator in &collaborators {
1903 session
1904 .peer
1905 .send(
1906 collaborator.peer_id.unwrap().into(),
1907 add_project_collaborator.clone(),
1908 )
1909 .trace_err();
1910 }
1911
1912 // First, we send the metadata associated with each worktree.
1913 response.send(proto::JoinProjectResponse {
1914 project_id: project.id.0 as u64,
1915 worktrees: worktrees.clone(),
1916 replica_id: replica_id.0 as u32,
1917 collaborators: collaborators.clone(),
1918 language_servers: project.language_servers.clone(),
1919 role: project.role.into(),
1920 })?;
1921
1922 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1923 // Stream this worktree's entries.
1924 let message = proto::UpdateWorktree {
1925 project_id: project_id.to_proto(),
1926 worktree_id,
1927 abs_path: worktree.abs_path.clone(),
1928 root_name: worktree.root_name,
1929 updated_entries: worktree.entries,
1930 removed_entries: Default::default(),
1931 scan_id: worktree.scan_id,
1932 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1933 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1934 removed_repositories: Default::default(),
1935 };
1936 for update in proto::split_worktree_update(message) {
1937 session.peer.send(session.connection_id, update.clone())?;
1938 }
1939
1940 // Stream this worktree's diagnostics.
1941 for summary in worktree.diagnostic_summaries {
1942 session.peer.send(
1943 session.connection_id,
1944 proto::UpdateDiagnosticSummary {
1945 project_id: project_id.to_proto(),
1946 worktree_id: worktree.id,
1947 summary: Some(summary),
1948 },
1949 )?;
1950 }
1951
1952 for settings_file in worktree.settings_files {
1953 session.peer.send(
1954 session.connection_id,
1955 proto::UpdateWorktreeSettings {
1956 project_id: project_id.to_proto(),
1957 worktree_id: worktree.id,
1958 path: settings_file.path,
1959 content: Some(settings_file.content),
1960 kind: Some(settings_file.kind.to_proto() as i32),
1961 },
1962 )?;
1963 }
1964 }
1965
1966 for repository in mem::take(&mut project.repositories) {
1967 for update in split_repository_update(repository) {
1968 session.peer.send(session.connection_id, update)?;
1969 }
1970 }
1971
1972 for language_server in &project.language_servers {
1973 session.peer.send(
1974 session.connection_id,
1975 proto::UpdateLanguageServer {
1976 project_id: project_id.to_proto(),
1977 language_server_id: language_server.id,
1978 variant: Some(
1979 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1980 proto::LspDiskBasedDiagnosticsUpdated {},
1981 ),
1982 ),
1983 },
1984 )?;
1985 }
1986
1987 Ok(())
1988}
1989
1990/// Leave someone elses shared project.
1991async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1992 let sender_id = session.connection_id;
1993 let project_id = ProjectId::from_proto(request.project_id);
1994 let db = session.db().await;
1995
1996 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1997 tracing::info!(
1998 %project_id,
1999 "leave project"
2000 );
2001
2002 project_left(project, &session);
2003 if let Some(room) = room {
2004 room_updated(room, &session.peer);
2005 }
2006
2007 Ok(())
2008}
2009
2010/// Updates other participants with changes to the project
2011async fn update_project(
2012 request: proto::UpdateProject,
2013 response: Response<proto::UpdateProject>,
2014 session: Session,
2015) -> Result<()> {
2016 let project_id = ProjectId::from_proto(request.project_id);
2017 let (room, guest_connection_ids) = &*session
2018 .db()
2019 .await
2020 .update_project(project_id, session.connection_id, &request.worktrees)
2021 .await?;
2022 broadcast(
2023 Some(session.connection_id),
2024 guest_connection_ids.iter().copied(),
2025 |connection_id| {
2026 session
2027 .peer
2028 .forward_send(session.connection_id, connection_id, request.clone())
2029 },
2030 );
2031 if let Some(room) = room {
2032 room_updated(room, &session.peer);
2033 }
2034 response.send(proto::Ack {})?;
2035
2036 Ok(())
2037}
2038
2039/// Updates other participants with changes to the worktree
2040async fn update_worktree(
2041 request: proto::UpdateWorktree,
2042 response: Response<proto::UpdateWorktree>,
2043 session: Session,
2044) -> Result<()> {
2045 let guest_connection_ids = session
2046 .db()
2047 .await
2048 .update_worktree(&request, session.connection_id)
2049 .await?;
2050
2051 broadcast(
2052 Some(session.connection_id),
2053 guest_connection_ids.iter().copied(),
2054 |connection_id| {
2055 session
2056 .peer
2057 .forward_send(session.connection_id, connection_id, request.clone())
2058 },
2059 );
2060 response.send(proto::Ack {})?;
2061 Ok(())
2062}
2063
2064async fn update_repository(
2065 request: proto::UpdateRepository,
2066 response: Response<proto::UpdateRepository>,
2067 session: Session,
2068) -> Result<()> {
2069 let guest_connection_ids = session
2070 .db()
2071 .await
2072 .update_repository(&request, session.connection_id)
2073 .await?;
2074
2075 broadcast(
2076 Some(session.connection_id),
2077 guest_connection_ids.iter().copied(),
2078 |connection_id| {
2079 session
2080 .peer
2081 .forward_send(session.connection_id, connection_id, request.clone())
2082 },
2083 );
2084 response.send(proto::Ack {})?;
2085 Ok(())
2086}
2087
2088async fn remove_repository(
2089 request: proto::RemoveRepository,
2090 response: Response<proto::RemoveRepository>,
2091 session: Session,
2092) -> Result<()> {
2093 let guest_connection_ids = session
2094 .db()
2095 .await
2096 .remove_repository(&request, session.connection_id)
2097 .await?;
2098
2099 broadcast(
2100 Some(session.connection_id),
2101 guest_connection_ids.iter().copied(),
2102 |connection_id| {
2103 session
2104 .peer
2105 .forward_send(session.connection_id, connection_id, request.clone())
2106 },
2107 );
2108 response.send(proto::Ack {})?;
2109 Ok(())
2110}
2111
2112/// Updates other participants with changes to the diagnostics
2113async fn update_diagnostic_summary(
2114 message: proto::UpdateDiagnosticSummary,
2115 session: Session,
2116) -> Result<()> {
2117 let guest_connection_ids = session
2118 .db()
2119 .await
2120 .update_diagnostic_summary(&message, session.connection_id)
2121 .await?;
2122
2123 broadcast(
2124 Some(session.connection_id),
2125 guest_connection_ids.iter().copied(),
2126 |connection_id| {
2127 session
2128 .peer
2129 .forward_send(session.connection_id, connection_id, message.clone())
2130 },
2131 );
2132
2133 Ok(())
2134}
2135
2136/// Updates other participants with changes to the worktree settings
2137async fn update_worktree_settings(
2138 message: proto::UpdateWorktreeSettings,
2139 session: Session,
2140) -> Result<()> {
2141 let guest_connection_ids = session
2142 .db()
2143 .await
2144 .update_worktree_settings(&message, session.connection_id)
2145 .await?;
2146
2147 broadcast(
2148 Some(session.connection_id),
2149 guest_connection_ids.iter().copied(),
2150 |connection_id| {
2151 session
2152 .peer
2153 .forward_send(session.connection_id, connection_id, message.clone())
2154 },
2155 );
2156
2157 Ok(())
2158}
2159
2160/// Notify other participants that a language server has started.
2161async fn start_language_server(
2162 request: proto::StartLanguageServer,
2163 session: Session,
2164) -> Result<()> {
2165 let guest_connection_ids = session
2166 .db()
2167 .await
2168 .start_language_server(&request, session.connection_id)
2169 .await?;
2170
2171 broadcast(
2172 Some(session.connection_id),
2173 guest_connection_ids.iter().copied(),
2174 |connection_id| {
2175 session
2176 .peer
2177 .forward_send(session.connection_id, connection_id, request.clone())
2178 },
2179 );
2180 Ok(())
2181}
2182
2183/// Notify other participants that a language server has changed.
2184async fn update_language_server(
2185 request: proto::UpdateLanguageServer,
2186 session: Session,
2187) -> Result<()> {
2188 let project_id = ProjectId::from_proto(request.project_id);
2189 let project_connection_ids = session
2190 .db()
2191 .await
2192 .project_connection_ids(project_id, session.connection_id, true)
2193 .await?;
2194 broadcast(
2195 Some(session.connection_id),
2196 project_connection_ids.iter().copied(),
2197 |connection_id| {
2198 session
2199 .peer
2200 .forward_send(session.connection_id, connection_id, request.clone())
2201 },
2202 );
2203 Ok(())
2204}
2205
2206/// forward a project request to the host. These requests should be read only
2207/// as guests are allowed to send them.
2208async fn forward_read_only_project_request<T>(
2209 request: T,
2210 response: Response<T>,
2211 session: Session,
2212) -> Result<()>
2213where
2214 T: EntityMessage + RequestMessage,
2215{
2216 let project_id = ProjectId::from_proto(request.remote_entity_id());
2217 let host_connection_id = session
2218 .db()
2219 .await
2220 .host_for_read_only_project_request(project_id, session.connection_id)
2221 .await?;
2222 let payload = session
2223 .peer
2224 .forward_request(session.connection_id, host_connection_id, request)
2225 .await?;
2226 response.send(payload)?;
2227 Ok(())
2228}
2229
2230async fn forward_find_search_candidates_request(
2231 request: proto::FindSearchCandidates,
2232 response: Response<proto::FindSearchCandidates>,
2233 session: Session,
2234) -> Result<()> {
2235 let project_id = ProjectId::from_proto(request.remote_entity_id());
2236 let host_connection_id = session
2237 .db()
2238 .await
2239 .host_for_read_only_project_request(project_id, session.connection_id)
2240 .await?;
2241 let payload = session
2242 .peer
2243 .forward_request(session.connection_id, host_connection_id, request)
2244 .await?;
2245 response.send(payload)?;
2246 Ok(())
2247}
2248
2249/// forward a project request to the host. These requests are disallowed
2250/// for guests.
2251async fn forward_mutating_project_request<T>(
2252 request: T,
2253 response: Response<T>,
2254 session: Session,
2255) -> Result<()>
2256where
2257 T: EntityMessage + RequestMessage,
2258{
2259 let project_id = ProjectId::from_proto(request.remote_entity_id());
2260
2261 let host_connection_id = session
2262 .db()
2263 .await
2264 .host_for_mutating_project_request(project_id, session.connection_id)
2265 .await?;
2266 let payload = session
2267 .peer
2268 .forward_request(session.connection_id, host_connection_id, request)
2269 .await?;
2270 response.send(payload)?;
2271 Ok(())
2272}
2273
2274/// Notify other participants that a new buffer has been created
2275async fn create_buffer_for_peer(
2276 request: proto::CreateBufferForPeer,
2277 session: Session,
2278) -> Result<()> {
2279 session
2280 .db()
2281 .await
2282 .check_user_is_project_host(
2283 ProjectId::from_proto(request.project_id),
2284 session.connection_id,
2285 )
2286 .await?;
2287 let peer_id = request.peer_id.context("invalid peer id")?;
2288 session
2289 .peer
2290 .forward_send(session.connection_id, peer_id.into(), request)?;
2291 Ok(())
2292}
2293
2294/// Notify other participants that a buffer has been updated. This is
2295/// allowed for guests as long as the update is limited to selections.
2296async fn update_buffer(
2297 request: proto::UpdateBuffer,
2298 response: Response<proto::UpdateBuffer>,
2299 session: Session,
2300) -> Result<()> {
2301 let project_id = ProjectId::from_proto(request.project_id);
2302 let mut capability = Capability::ReadOnly;
2303
2304 for op in request.operations.iter() {
2305 match op.variant {
2306 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2307 Some(_) => capability = Capability::ReadWrite,
2308 }
2309 }
2310
2311 let host = {
2312 let guard = session
2313 .db()
2314 .await
2315 .connections_for_buffer_update(project_id, session.connection_id, capability)
2316 .await?;
2317
2318 let (host, guests) = &*guard;
2319
2320 broadcast(
2321 Some(session.connection_id),
2322 guests.clone(),
2323 |connection_id| {
2324 session
2325 .peer
2326 .forward_send(session.connection_id, connection_id, request.clone())
2327 },
2328 );
2329
2330 *host
2331 };
2332
2333 if host != session.connection_id {
2334 session
2335 .peer
2336 .forward_request(session.connection_id, host, request.clone())
2337 .await?;
2338 }
2339
2340 response.send(proto::Ack {})?;
2341 Ok(())
2342}
2343
2344async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2345 let project_id = ProjectId::from_proto(message.project_id);
2346
2347 let operation = message.operation.as_ref().context("invalid operation")?;
2348 let capability = match operation.variant.as_ref() {
2349 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2350 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2351 match buffer_op.variant {
2352 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2353 Capability::ReadOnly
2354 }
2355 _ => Capability::ReadWrite,
2356 }
2357 } else {
2358 Capability::ReadWrite
2359 }
2360 }
2361 Some(_) => Capability::ReadWrite,
2362 None => Capability::ReadOnly,
2363 };
2364
2365 let guard = session
2366 .db()
2367 .await
2368 .connections_for_buffer_update(project_id, session.connection_id, capability)
2369 .await?;
2370
2371 let (host, guests) = &*guard;
2372
2373 broadcast(
2374 Some(session.connection_id),
2375 guests.iter().chain([host]).copied(),
2376 |connection_id| {
2377 session
2378 .peer
2379 .forward_send(session.connection_id, connection_id, message.clone())
2380 },
2381 );
2382
2383 Ok(())
2384}
2385
2386/// Notify other participants that a project has been updated.
2387async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2388 request: T,
2389 session: Session,
2390) -> Result<()> {
2391 let project_id = ProjectId::from_proto(request.remote_entity_id());
2392 let project_connection_ids = session
2393 .db()
2394 .await
2395 .project_connection_ids(project_id, session.connection_id, false)
2396 .await?;
2397
2398 broadcast(
2399 Some(session.connection_id),
2400 project_connection_ids.iter().copied(),
2401 |connection_id| {
2402 session
2403 .peer
2404 .forward_send(session.connection_id, connection_id, request.clone())
2405 },
2406 );
2407 Ok(())
2408}
2409
2410/// Start following another user in a call.
2411async fn follow(
2412 request: proto::Follow,
2413 response: Response<proto::Follow>,
2414 session: Session,
2415) -> Result<()> {
2416 let room_id = RoomId::from_proto(request.room_id);
2417 let project_id = request.project_id.map(ProjectId::from_proto);
2418 let leader_id = request.leader_id.context("invalid leader id")?.into();
2419 let follower_id = session.connection_id;
2420
2421 session
2422 .db()
2423 .await
2424 .check_room_participants(room_id, leader_id, session.connection_id)
2425 .await?;
2426
2427 let response_payload = session
2428 .peer
2429 .forward_request(session.connection_id, leader_id, request)
2430 .await?;
2431 response.send(response_payload)?;
2432
2433 if let Some(project_id) = project_id {
2434 let room = session
2435 .db()
2436 .await
2437 .follow(room_id, project_id, leader_id, follower_id)
2438 .await?;
2439 room_updated(&room, &session.peer);
2440 }
2441
2442 Ok(())
2443}
2444
2445/// Stop following another user in a call.
2446async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2447 let room_id = RoomId::from_proto(request.room_id);
2448 let project_id = request.project_id.map(ProjectId::from_proto);
2449 let leader_id = request.leader_id.context("invalid leader id")?.into();
2450 let follower_id = session.connection_id;
2451
2452 session
2453 .db()
2454 .await
2455 .check_room_participants(room_id, leader_id, session.connection_id)
2456 .await?;
2457
2458 session
2459 .peer
2460 .forward_send(session.connection_id, leader_id, request)?;
2461
2462 if let Some(project_id) = project_id {
2463 let room = session
2464 .db()
2465 .await
2466 .unfollow(room_id, project_id, leader_id, follower_id)
2467 .await?;
2468 room_updated(&room, &session.peer);
2469 }
2470
2471 Ok(())
2472}
2473
2474/// Notify everyone following you of your current location.
2475async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2476 let room_id = RoomId::from_proto(request.room_id);
2477 let database = session.db.lock().await;
2478
2479 let connection_ids = if let Some(project_id) = request.project_id {
2480 let project_id = ProjectId::from_proto(project_id);
2481 database
2482 .project_connection_ids(project_id, session.connection_id, true)
2483 .await?
2484 } else {
2485 database
2486 .room_connection_ids(room_id, session.connection_id)
2487 .await?
2488 };
2489
2490 // For now, don't send view update messages back to that view's current leader.
2491 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2492 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2493 _ => None,
2494 });
2495
2496 for connection_id in connection_ids.iter().cloned() {
2497 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2498 session
2499 .peer
2500 .forward_send(session.connection_id, connection_id, request.clone())?;
2501 }
2502 }
2503 Ok(())
2504}
2505
2506/// Get public data about users.
2507async fn get_users(
2508 request: proto::GetUsers,
2509 response: Response<proto::GetUsers>,
2510 session: Session,
2511) -> Result<()> {
2512 let user_ids = request
2513 .user_ids
2514 .into_iter()
2515 .map(UserId::from_proto)
2516 .collect();
2517 let users = session
2518 .db()
2519 .await
2520 .get_users_by_ids(user_ids)
2521 .await?
2522 .into_iter()
2523 .map(|user| proto::User {
2524 id: user.id.to_proto(),
2525 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2526 github_login: user.github_login,
2527 email: user.email_address,
2528 name: user.name,
2529 })
2530 .collect();
2531 response.send(proto::UsersResponse { users })?;
2532 Ok(())
2533}
2534
2535/// Search for users (to invite) buy Github login
2536async fn fuzzy_search_users(
2537 request: proto::FuzzySearchUsers,
2538 response: Response<proto::FuzzySearchUsers>,
2539 session: Session,
2540) -> Result<()> {
2541 let query = request.query;
2542 let users = match query.len() {
2543 0 => vec![],
2544 1 | 2 => session
2545 .db()
2546 .await
2547 .get_user_by_github_login(&query)
2548 .await?
2549 .into_iter()
2550 .collect(),
2551 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2552 };
2553 let users = users
2554 .into_iter()
2555 .filter(|user| user.id != session.user_id())
2556 .map(|user| proto::User {
2557 id: user.id.to_proto(),
2558 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2559 github_login: user.github_login,
2560 name: user.name,
2561 email: user.email_address,
2562 })
2563 .collect();
2564 response.send(proto::UsersResponse { users })?;
2565 Ok(())
2566}
2567
2568/// Send a contact request to another user.
2569async fn request_contact(
2570 request: proto::RequestContact,
2571 response: Response<proto::RequestContact>,
2572 session: Session,
2573) -> Result<()> {
2574 let requester_id = session.user_id();
2575 let responder_id = UserId::from_proto(request.responder_id);
2576 if requester_id == responder_id {
2577 return Err(anyhow!("cannot add yourself as a contact"))?;
2578 }
2579
2580 let notifications = session
2581 .db()
2582 .await
2583 .send_contact_request(requester_id, responder_id)
2584 .await?;
2585
2586 // Update outgoing contact requests of requester
2587 let mut update = proto::UpdateContacts::default();
2588 update.outgoing_requests.push(responder_id.to_proto());
2589 for connection_id in session
2590 .connection_pool()
2591 .await
2592 .user_connection_ids(requester_id)
2593 {
2594 session.peer.send(connection_id, update.clone())?;
2595 }
2596
2597 // Update incoming contact requests of responder
2598 let mut update = proto::UpdateContacts::default();
2599 update
2600 .incoming_requests
2601 .push(proto::IncomingContactRequest {
2602 requester_id: requester_id.to_proto(),
2603 });
2604 let connection_pool = session.connection_pool().await;
2605 for connection_id in connection_pool.user_connection_ids(responder_id) {
2606 session.peer.send(connection_id, update.clone())?;
2607 }
2608
2609 send_notifications(&connection_pool, &session.peer, notifications);
2610
2611 response.send(proto::Ack {})?;
2612 Ok(())
2613}
2614
/// Accept, decline, or dismiss a contact request sent to the current user.
///
/// - `Dismiss` only calls `dismiss_contact_notification`; no contact updates
///   are sent.
/// - `Accept` and any other response value (treated as a decline) resolve the
///   request in the database. Every connection of both users is then updated:
///   the pending request is removed on each side, and on accept each user is
///   sent the other as a new contact together with their busy status.
///   Notifications returned by the database are delivered afterwards.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // This database guard stays held for the rest of the handler, including
    // while the connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Accept` (and `Dismiss`, handled above) declines.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are embedded in the contact entries sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2672
/// Remove a contact, or withdraw a still-pending contact request.
///
/// The database reports whether the two users were already contacts
/// (`contact_accepted`):
/// - If they were, each side is told to drop the other from its contacts.
/// - Otherwise, the requester's outgoing request and the responder's incoming
///   request are removed.
///
/// If the database also deleted a notification, the responder's connections are
/// told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    // This database guard stays held for the rest of the handler, including
    // while the connection pool is locked below.
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the notification the database deleted, if any.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2723
2724fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2725 version.0.minor() < 139
2726}
2727
2728async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2729 if is_staff {
2730 return Ok(proto::Plan::ZedPro);
2731 }
2732
2733 let subscription = db.get_active_billing_subscription(user_id).await?;
2734 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2735
2736 let plan = if let Some(subscription_kind) = subscription_kind {
2737 match subscription_kind {
2738 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2739 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2740 SubscriptionKind::ZedFree => proto::Plan::Free,
2741 }
2742 } else {
2743 proto::Plan::Free
2744 };
2745
2746 Ok(plan)
2747}
2748
/// Builds the `UpdateUserPlan` message for `user`.
///
/// The message covers:
/// - the user's plan (staff report Zed Pro, via `current_plan`);
/// - trial start time, whether usage-based billing is enabled, and overdue
///   invoices;
/// - whether the account is too young to use the LLM features;
/// - when `llm_db` is provided, the current subscription period and usage.
///
/// `account_too_young` is set for users who are not on Zed Pro, lack the
/// `BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG` flag, and whose account is younger
/// than `MIN_ACCOUNT_AGE_FOR_LLM_USE`.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // The subscription period and usage are only computed when an LLM database
    // is available; otherwise both are reported as absent.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    // Only Zed Pro is exempt from the age gate; `ZedProTrial` and `Free` are not.
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // NOTE(review): `as u64` assumes a post-epoch timestamp; a negative value would wrap.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always report usage-based billing as enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: usage.map(|usage| {
            // Translate the wire-level plan into the LLM client's plan type so
            // its per-plan limits can be queried.
            let plan = match plan {
                proto::Plan::Free => zed_llm_client::Plan::ZedFree,
                proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
                proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
            };

            // Users on the Pro trial who have the extended-trial flag get a
            // model-request limit of 1,000 instead of the plan's default.
            let model_requests_limit = match plan.model_requests_limit() {
                zed_llm_client::UsageLimit::Limited(limit) => {
                    let limit = if plan == zed_llm_client::Plan::ZedProTrial
                        && feature_flags
                            .iter()
                            .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
                    {
                        1_000
                    } else {
                        limit
                    };

                    zed_llm_client::UsageLimit::Limited(limit)
                }
                zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
            };

            // NOTE(review): the `as u32` casts below truncate silently if a
            // count or limit ever exceeds u32::MAX.
            proto::SubscriptionUsage {
                model_requests_usage_amount: usage.model_requests as u32,
                model_requests_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match model_requests_limit {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
                edit_predictions_usage_amount: usage.edit_predictions as u32,
                edit_predictions_usage_limit: Some(proto::UsageLimit {
                    variant: Some(match plan.edit_predictions_limit() {
                        zed_llm_client::UsageLimit::Limited(limit) => {
                            proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                                limit: limit as u32,
                            })
                        }
                        zed_llm_client::UsageLimit::Unlimited => {
                            proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                        }
                    }),
                }),
            }
        }),
    })
}
2861
2862async fn update_user_plan(session: &Session) -> Result<()> {
2863 let db = session.db().await;
2864
2865 let update_user_plan = make_update_user_plan_message(
2866 session.principal.user(),
2867 session.is_staff(),
2868 &db.0,
2869 session.app_state.llm_db.clone(),
2870 )
2871 .await?;
2872
2873 session
2874 .peer
2875 .send(session.connection_id, update_user_plan)
2876 .trace_err();
2877
2878 Ok(())
2879}
2880
2881async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2882 subscribe_user_to_channels(session.user_id(), &session).await?;
2883 Ok(())
2884}
2885
2886async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2887 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2888 let mut pool = session.connection_pool().await;
2889 for membership in &channels_for_user.channel_memberships {
2890 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2891 }
2892 session.peer.send(
2893 session.connection_id,
2894 build_update_user_channels(&channels_for_user),
2895 )?;
2896 session.peer.send(
2897 session.connection_id,
2898 build_channels_update(channels_for_user),
2899 )?;
2900 Ok(())
2901}
2902
2903/// Creates a new channel.
2904async fn create_channel(
2905 request: proto::CreateChannel,
2906 response: Response<proto::CreateChannel>,
2907 session: Session,
2908) -> Result<()> {
2909 let db = session.db().await;
2910
2911 let parent_id = request.parent_id.map(ChannelId::from_proto);
2912 let (channel, membership) = db
2913 .create_channel(&request.name, parent_id, session.user_id())
2914 .await?;
2915
2916 let root_id = channel.root_id();
2917 let channel = Channel::from_model(channel);
2918
2919 response.send(proto::CreateChannelResponse {
2920 channel: Some(channel.to_proto()),
2921 parent_id: request.parent_id,
2922 })?;
2923
2924 let mut connection_pool = session.connection_pool().await;
2925 if let Some(membership) = membership {
2926 connection_pool.subscribe_to_channel(
2927 membership.user_id,
2928 membership.channel_id,
2929 membership.role,
2930 );
2931 let update = proto::UpdateUserChannels {
2932 channel_memberships: vec![proto::ChannelMembership {
2933 channel_id: membership.channel_id.to_proto(),
2934 role: membership.role.into(),
2935 }],
2936 ..Default::default()
2937 };
2938 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2939 session.peer.send(connection_id, update.clone())?;
2940 }
2941 }
2942
2943 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2944 if !role.can_see_channel(channel.visibility) {
2945 continue;
2946 }
2947
2948 let update = proto::UpdateChannels {
2949 channels: vec![channel.to_proto()],
2950 ..Default::default()
2951 };
2952 session.peer.send(connection_id, update.clone())?;
2953 }
2954
2955 Ok(())
2956}
2957
2958/// Delete a channel
2959async fn delete_channel(
2960 request: proto::DeleteChannel,
2961 response: Response<proto::DeleteChannel>,
2962 session: Session,
2963) -> Result<()> {
2964 let db = session.db().await;
2965
2966 let channel_id = request.channel_id;
2967 let (root_channel, removed_channels) = db
2968 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2969 .await?;
2970 response.send(proto::Ack {})?;
2971
2972 // Notify members of removed channels
2973 let mut update = proto::UpdateChannels::default();
2974 update
2975 .delete_channels
2976 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2977
2978 let connection_pool = session.connection_pool().await;
2979 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2980 session.peer.send(connection_id, update.clone())?;
2981 }
2982
2983 Ok(())
2984}
2985
2986/// Invite someone to join a channel.
2987async fn invite_channel_member(
2988 request: proto::InviteChannelMember,
2989 response: Response<proto::InviteChannelMember>,
2990 session: Session,
2991) -> Result<()> {
2992 let db = session.db().await;
2993 let channel_id = ChannelId::from_proto(request.channel_id);
2994 let invitee_id = UserId::from_proto(request.user_id);
2995 let InviteMemberResult {
2996 channel,
2997 notifications,
2998 } = db
2999 .invite_channel_member(
3000 channel_id,
3001 invitee_id,
3002 session.user_id(),
3003 request.role().into(),
3004 )
3005 .await?;
3006
3007 let update = proto::UpdateChannels {
3008 channel_invitations: vec![channel.to_proto()],
3009 ..Default::default()
3010 };
3011
3012 let connection_pool = session.connection_pool().await;
3013 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3014 session.peer.send(connection_id, update.clone())?;
3015 }
3016
3017 send_notifications(&connection_pool, &session.peer, notifications);
3018
3019 response.send(proto::Ack {})?;
3020 Ok(())
3021}
3022
3023/// remove someone from a channel
3024async fn remove_channel_member(
3025 request: proto::RemoveChannelMember,
3026 response: Response<proto::RemoveChannelMember>,
3027 session: Session,
3028) -> Result<()> {
3029 let db = session.db().await;
3030 let channel_id = ChannelId::from_proto(request.channel_id);
3031 let member_id = UserId::from_proto(request.user_id);
3032
3033 let RemoveChannelMemberResult {
3034 membership_update,
3035 notification_id,
3036 } = db
3037 .remove_channel_member(channel_id, member_id, session.user_id())
3038 .await?;
3039
3040 let mut connection_pool = session.connection_pool().await;
3041 notify_membership_updated(
3042 &mut connection_pool,
3043 membership_update,
3044 member_id,
3045 &session.peer,
3046 );
3047 for connection_id in connection_pool.user_connection_ids(member_id) {
3048 if let Some(notification_id) = notification_id {
3049 session
3050 .peer
3051 .send(
3052 connection_id,
3053 proto::DeleteNotification {
3054 notification_id: notification_id.to_proto(),
3055 },
3056 )
3057 .trace_err();
3058 }
3059 }
3060
3061 response.send(proto::Ack {})?;
3062 Ok(())
3063}
3064
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
/// NOTE(review): this handler does not check that invariant itself; presumably
/// `Database::set_channel_visibility` enforces it.
///
/// Afterwards, every user subscribed to the channel's root is re-evaluated
/// against the new visibility:
/// - users whose role can see the channel are subscribed to it and sent it;
/// - all others are unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates
    // `connection_pool`, so it cannot keep borrowing it through the iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3111
3112/// Alter the role for a user in the channel.
3113async fn set_channel_member_role(
3114 request: proto::SetChannelMemberRole,
3115 response: Response<proto::SetChannelMemberRole>,
3116 session: Session,
3117) -> Result<()> {
3118 let db = session.db().await;
3119 let channel_id = ChannelId::from_proto(request.channel_id);
3120 let member_id = UserId::from_proto(request.user_id);
3121 let result = db
3122 .set_channel_member_role(
3123 channel_id,
3124 session.user_id(),
3125 member_id,
3126 request.role().into(),
3127 )
3128 .await?;
3129
3130 match result {
3131 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3132 let mut connection_pool = session.connection_pool().await;
3133 notify_membership_updated(
3134 &mut connection_pool,
3135 membership_update,
3136 member_id,
3137 &session.peer,
3138 )
3139 }
3140 db::SetMemberRoleResult::InviteUpdated(channel) => {
3141 let update = proto::UpdateChannels {
3142 channel_invitations: vec![channel.to_proto()],
3143 ..Default::default()
3144 };
3145
3146 for connection_id in session
3147 .connection_pool()
3148 .await
3149 .user_connection_ids(member_id)
3150 {
3151 session.peer.send(connection_id, update.clone())?;
3152 }
3153 }
3154 }
3155
3156 response.send(proto::Ack {})?;
3157 Ok(())
3158}
3159
3160/// Change the name of a channel
3161async fn rename_channel(
3162 request: proto::RenameChannel,
3163 response: Response<proto::RenameChannel>,
3164 session: Session,
3165) -> Result<()> {
3166 let db = session.db().await;
3167 let channel_id = ChannelId::from_proto(request.channel_id);
3168 let channel_model = db
3169 .rename_channel(channel_id, session.user_id(), &request.name)
3170 .await?;
3171 let root_id = channel_model.root_id();
3172 let channel = Channel::from_model(channel_model);
3173
3174 response.send(proto::RenameChannelResponse {
3175 channel: Some(channel.to_proto()),
3176 })?;
3177
3178 let connection_pool = session.connection_pool().await;
3179 let update = proto::UpdateChannels {
3180 channels: vec![channel.to_proto()],
3181 ..Default::default()
3182 };
3183 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3184 if role.can_see_channel(channel.visibility) {
3185 session.peer.send(connection_id, update.clone())?;
3186 }
3187 }
3188
3189 Ok(())
3190}
3191
3192/// Move a channel to a new parent.
3193async fn move_channel(
3194 request: proto::MoveChannel,
3195 response: Response<proto::MoveChannel>,
3196 session: Session,
3197) -> Result<()> {
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199 let to = ChannelId::from_proto(request.to);
3200
3201 let (root_id, channels) = session
3202 .db()
3203 .await
3204 .move_channel(channel_id, to, session.user_id())
3205 .await?;
3206
3207 let connection_pool = session.connection_pool().await;
3208 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3209 let channels = channels
3210 .iter()
3211 .filter_map(|channel| {
3212 if role.can_see_channel(channel.visibility) {
3213 Some(channel.to_proto())
3214 } else {
3215 None
3216 }
3217 })
3218 .collect::<Vec<_>>();
3219 if channels.is_empty() {
3220 continue;
3221 }
3222
3223 let update = proto::UpdateChannels {
3224 channels,
3225 ..Default::default()
3226 };
3227
3228 session.peer.send(connection_id, update.clone())?;
3229 }
3230
3231 response.send(Ack {})?;
3232 Ok(())
3233}
3234
3235async fn reorder_channel(
3236 request: proto::ReorderChannel,
3237 response: Response<proto::ReorderChannel>,
3238 session: Session,
3239) -> Result<()> {
3240 let channel_id = ChannelId::from_proto(request.channel_id);
3241 let direction = request.direction();
3242
3243 let updated_channels = session
3244 .db()
3245 .await
3246 .reorder_channel(channel_id, direction, session.user_id())
3247 .await?;
3248
3249 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3250 let connection_pool = session.connection_pool().await;
3251 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3252 let channels = updated_channels
3253 .iter()
3254 .filter_map(|channel| {
3255 if role.can_see_channel(channel.visibility) {
3256 Some(channel.to_proto())
3257 } else {
3258 None
3259 }
3260 })
3261 .collect::<Vec<_>>();
3262
3263 if channels.is_empty() {
3264 continue;
3265 }
3266
3267 let update = proto::UpdateChannels {
3268 channels,
3269 ..Default::default()
3270 };
3271
3272 session.peer.send(connection_id, update.clone())?;
3273 }
3274 }
3275
3276 response.send(Ack {})?;
3277 Ok(())
3278}
3279
3280/// Get the list of channel members
3281async fn get_channel_members(
3282 request: proto::GetChannelMembers,
3283 response: Response<proto::GetChannelMembers>,
3284 session: Session,
3285) -> Result<()> {
3286 let db = session.db().await;
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288 let limit = if request.limit == 0 {
3289 u16::MAX as u64
3290 } else {
3291 request.limit
3292 };
3293 let (members, users) = db
3294 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3295 .await?;
3296 response.send(proto::GetChannelMembersResponse { members, users })?;
3297 Ok(())
3298}
3299
3300/// Accept or decline a channel invitation.
3301async fn respond_to_channel_invite(
3302 request: proto::RespondToChannelInvite,
3303 response: Response<proto::RespondToChannelInvite>,
3304 session: Session,
3305) -> Result<()> {
3306 let db = session.db().await;
3307 let channel_id = ChannelId::from_proto(request.channel_id);
3308 let RespondToChannelInvite {
3309 membership_update,
3310 notifications,
3311 } = db
3312 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3313 .await?;
3314
3315 let mut connection_pool = session.connection_pool().await;
3316 if let Some(membership_update) = membership_update {
3317 notify_membership_updated(
3318 &mut connection_pool,
3319 membership_update,
3320 session.user_id(),
3321 &session.peer,
3322 );
3323 } else {
3324 let update = proto::UpdateChannels {
3325 remove_channel_invitations: vec![channel_id.to_proto()],
3326 ..Default::default()
3327 };
3328
3329 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3330 session.peer.send(connection_id, update.clone())?;
3331 }
3332 };
3333
3334 send_notifications(&connection_pool, &session.peer, notifications);
3335
3336 response.send(proto::Ack {})?;
3337
3338 Ok(())
3339}
3340
3341/// Join the channels' room
3342async fn join_channel(
3343 request: proto::JoinChannel,
3344 response: Response<proto::JoinChannel>,
3345 session: Session,
3346) -> Result<()> {
3347 let channel_id = ChannelId::from_proto(request.channel_id);
3348 join_channel_internal(channel_id, Box::new(response), session).await
3349}
3350
/// Abstraction over the response handles that can answer a channel join.
///
/// Both the `JoinChannel` and `JoinRoom` responses below accept a
/// `proto::JoinRoomResponse`, so `join_channel_internal` is written against this
/// trait instead of a concrete `Response<_>` type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path to `Response`'s own `send`. It is written this way
        // rather than `self.send(..)` to avoid any ambiguity with this trait
        // method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3364
/// Shared implementation for joining a channel's room, generic over the reply
/// handle via `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The session DB guard is dropped while doing so and
///    re-acquired afterwards.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the caller's role.
/// 3. If a LiveKit client is configured, mint credentials. Guests get a guest
///    token with `can_publish = false`; every other role gets a room token with
///    `can_publish = true`. A token error is traced and yields no connection
///    info instead of failing the join.
/// 4. Reply to the caller, apply any membership update, and broadcast the room.
/// 5. After the inner scope ends (dropping the DB guard and the first pool
///    guard), broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Fails if any database call, the reply, or the contacts update fails, or if
/// the joined room has no channel. In that last case the reply has already
/// been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the session DB guard before leaving the stale room, then
            // take it again for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (`trace_err()?` logs the error and short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may only subscribe; all other roles may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this scope drops `db` and `connection_pool`.
        joined_room
    };

    // A fresh connection-pool guard is taken for the channel broadcast.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3458
3459/// Start editing the channel notes
3460async fn join_channel_buffer(
3461 request: proto::JoinChannelBuffer,
3462 response: Response<proto::JoinChannelBuffer>,
3463 session: Session,
3464) -> Result<()> {
3465 let db = session.db().await;
3466 let channel_id = ChannelId::from_proto(request.channel_id);
3467
3468 let open_response = db
3469 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3470 .await?;
3471
3472 let collaborators = open_response.collaborators.clone();
3473 response.send(open_response)?;
3474
3475 let update = UpdateChannelBufferCollaborators {
3476 channel_id: channel_id.to_proto(),
3477 collaborators: collaborators.clone(),
3478 };
3479 channel_buffer_updated(
3480 session.connection_id,
3481 collaborators
3482 .iter()
3483 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3484 &update,
3485 &session.peer,
3486 );
3487
3488 Ok(())
3489}
3490
3491/// Edit the channel notes
3492async fn update_channel_buffer(
3493 request: proto::UpdateChannelBuffer,
3494 session: Session,
3495) -> Result<()> {
3496 let db = session.db().await;
3497 let channel_id = ChannelId::from_proto(request.channel_id);
3498
3499 let (collaborators, epoch, version) = db
3500 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3501 .await?;
3502
3503 channel_buffer_updated(
3504 session.connection_id,
3505 collaborators.clone(),
3506 &proto::UpdateChannelBuffer {
3507 channel_id: channel_id.to_proto(),
3508 operations: request.operations,
3509 },
3510 &session.peer,
3511 );
3512
3513 let pool = &*session.connection_pool().await;
3514
3515 let non_collaborators =
3516 pool.channel_connection_ids(channel_id)
3517 .filter_map(|(connection_id, _)| {
3518 if collaborators.contains(&connection_id) {
3519 None
3520 } else {
3521 Some(connection_id)
3522 }
3523 });
3524
3525 broadcast(None, non_collaborators, |peer_id| {
3526 session.peer.send(
3527 peer_id,
3528 proto::UpdateChannels {
3529 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3530 channel_id: channel_id.to_proto(),
3531 epoch: epoch as u64,
3532 version: version.clone(),
3533 }],
3534 ..Default::default()
3535 },
3536 )
3537 });
3538
3539 Ok(())
3540}
3541
3542/// Rejoin the channel notes after a connection blip
3543async fn rejoin_channel_buffers(
3544 request: proto::RejoinChannelBuffers,
3545 response: Response<proto::RejoinChannelBuffers>,
3546 session: Session,
3547) -> Result<()> {
3548 let db = session.db().await;
3549 let buffers = db
3550 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3551 .await?;
3552
3553 for rejoined_buffer in &buffers {
3554 let collaborators_to_notify = rejoined_buffer
3555 .buffer
3556 .collaborators
3557 .iter()
3558 .filter_map(|c| Some(c.peer_id?.into()));
3559 channel_buffer_updated(
3560 session.connection_id,
3561 collaborators_to_notify,
3562 &proto::UpdateChannelBufferCollaborators {
3563 channel_id: rejoined_buffer.buffer.channel_id,
3564 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3565 },
3566 &session.peer,
3567 );
3568 }
3569
3570 response.send(proto::RejoinChannelBuffersResponse {
3571 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3572 })?;
3573
3574 Ok(())
3575}
3576
3577/// Stop editing the channel notes
3578async fn leave_channel_buffer(
3579 request: proto::LeaveChannelBuffer,
3580 response: Response<proto::LeaveChannelBuffer>,
3581 session: Session,
3582) -> Result<()> {
3583 let db = session.db().await;
3584 let channel_id = ChannelId::from_proto(request.channel_id);
3585
3586 let left_buffer = db
3587 .leave_channel_buffer(channel_id, session.connection_id)
3588 .await?;
3589
3590 response.send(Ack {})?;
3591
3592 channel_buffer_updated(
3593 session.connection_id,
3594 left_buffer.connections,
3595 &proto::UpdateChannelBufferCollaborators {
3596 channel_id: channel_id.to_proto(),
3597 collaborators: left_buffer.collaborators,
3598 },
3599 &session.peer,
3600 );
3601
3602 Ok(())
3603}
3604
3605fn channel_buffer_updated<T: EnvelopedMessage>(
3606 sender_id: ConnectionId,
3607 collaborators: impl IntoIterator<Item = ConnectionId>,
3608 message: &T,
3609 peer: &Peer,
3610) {
3611 broadcast(Some(sender_id), collaborators, |peer_id| {
3612 peer.send(peer_id, message.clone())
3613 });
3614}
3615
3616fn send_notifications(
3617 connection_pool: &ConnectionPool,
3618 peer: &Peer,
3619 notifications: db::NotificationBatch,
3620) {
3621 for (user_id, notification) in notifications {
3622 for connection_id in connection_pool.user_connection_ids(user_id) {
3623 if let Err(error) = peer.send(
3624 connection_id,
3625 proto::AddNotification {
3626 notification: Some(notification.clone()),
3627 },
3628 ) {
3629 tracing::error!(
3630 "failed to send notification to {:?} {}",
3631 connection_id,
3632 error
3633 );
3634 }
3635 }
3636 }
3637}
3638
/// Post a new message to a channel's chat.
///
/// Flow:
/// 1. Trim the body. Reject it if it is longer than `MAX_MESSAGE_LEN` (as
///    measured by `str::len`, i.e. bytes) or empty, and require a nonce.
/// 2. Persist the message. This yields its id, the participant connection ids,
///    and notifications to deliver.
/// 3. Send `ChannelMessageSent` to those participants, with the caller's
///    connection passed to `broadcast` as the origin. Then reply to the caller
///    with the stored message.
/// 4. Send the latest message id to every other connection subscribed to the
///    channel, and deliver the notifications.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is passed through unchanged below. If
    // leading whitespace was trimmed, any offsets carried by the mentions may
    // no longer line up with `body`. Confirm how mention ranges are encoded.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    // The DB guard is a temporary and is released at the end of this statement.
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to chat participants. The caller's connection is
    // passed as the origin.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not in the chat only learn the latest
    // message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3731
/// Delete a channel message.
///
/// The database call returns the connections to notify and the ids in
/// `existing_notification_ids`. Each of those connections is sent the original
/// `RemoveChannelMessage` request, followed by one `DeleteNotification` per id.
/// The caller's connection is passed to `broadcast` as the origin. The caller
/// is then acknowledged.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    // `Some(session.connection_id)` is evaluated before the `move` closure
    // takes ownership of `session`, `request` and the notification ids.
    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            // Forward the removal request itself to each recipient.
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
3767
3768async fn update_channel_message(
3769 request: proto::UpdateChannelMessage,
3770 response: Response<proto::UpdateChannelMessage>,
3771 session: Session,
3772) -> Result<()> {
3773 let channel_id = ChannelId::from_proto(request.channel_id);
3774 let message_id = MessageId::from_proto(request.message_id);
3775 let updated_at = OffsetDateTime::now_utc();
3776 let UpdatedChannelMessage {
3777 message_id,
3778 participant_connection_ids,
3779 notifications,
3780 reply_to_message_id,
3781 timestamp,
3782 deleted_mention_notification_ids,
3783 updated_mention_notifications,
3784 } = session
3785 .db()
3786 .await
3787 .update_channel_message(
3788 channel_id,
3789 message_id,
3790 session.user_id(),
3791 request.body.as_str(),
3792 &request.mentions,
3793 updated_at,
3794 )
3795 .await?;
3796
3797 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3798
3799 let message = proto::ChannelMessage {
3800 sender_id: session.user_id().to_proto(),
3801 id: message_id.to_proto(),
3802 body: request.body.clone(),
3803 mentions: request.mentions.clone(),
3804 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3805 nonce: Some(nonce),
3806 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3807 edited_at: Some(updated_at.unix_timestamp() as u64),
3808 };
3809
3810 response.send(proto::Ack {})?;
3811
3812 let pool = &*session.connection_pool().await;
3813 broadcast(
3814 Some(session.connection_id),
3815 participant_connection_ids,
3816 |connection| {
3817 session.peer.send(
3818 connection,
3819 proto::ChannelMessageUpdate {
3820 channel_id: channel_id.to_proto(),
3821 message: Some(message.clone()),
3822 },
3823 )?;
3824
3825 for notification_id in &deleted_mention_notification_ids {
3826 session.peer.send(
3827 connection,
3828 proto::DeleteNotification {
3829 notification_id: (*notification_id).to_proto(),
3830 },
3831 )?;
3832 }
3833
3834 for notification in &updated_mention_notifications {
3835 session.peer.send(
3836 connection,
3837 proto::UpdateNotification {
3838 notification: Some(notification.clone()),
3839 },
3840 )?;
3841 }
3842
3843 Ok(())
3844 },
3845 );
3846
3847 send_notifications(pool, &session.peer, notifications);
3848
3849 Ok(())
3850}
3851
3852/// Mark a channel message as read
3853async fn acknowledge_channel_message(
3854 request: proto::AckChannelMessage,
3855 session: Session,
3856) -> Result<()> {
3857 let channel_id = ChannelId::from_proto(request.channel_id);
3858 let message_id = MessageId::from_proto(request.message_id);
3859 let notifications = session
3860 .db()
3861 .await
3862 .observe_channel_message(channel_id, session.user_id(), message_id)
3863 .await?;
3864 send_notifications(
3865 &*session.connection_pool().await,
3866 &session.peer,
3867 notifications,
3868 );
3869 Ok(())
3870}
3871
3872/// Mark a buffer version as synced
3873async fn acknowledge_buffer_version(
3874 request: proto::AckBufferOperation,
3875 session: Session,
3876) -> Result<()> {
3877 let buffer_id = BufferId::from_proto(request.buffer_id);
3878 session
3879 .db()
3880 .await
3881 .observe_buffer_version(
3882 buffer_id,
3883 session.user_id(),
3884 request.epoch as i32,
3885 &request.version,
3886 )
3887 .await?;
3888 Ok(())
3889}
3890
3891/// Get a Supermaven API key for the user
3892async fn get_supermaven_api_key(
3893 _request: proto::GetSupermavenApiKey,
3894 response: Response<proto::GetSupermavenApiKey>,
3895 session: Session,
3896) -> Result<()> {
3897 let user_id: String = session.user_id().to_string();
3898 if !session.is_staff() {
3899 return Err(anyhow!("supermaven not enabled for this account"))?;
3900 }
3901
3902 let email = session.email().context("user must have an email")?;
3903
3904 let supermaven_admin_api = session
3905 .supermaven_client
3906 .as_ref()
3907 .context("supermaven not configured")?;
3908
3909 let result = supermaven_admin_api
3910 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3911 .await?;
3912
3913 response.send(proto::GetSupermavenApiKeyResponse {
3914 api_key: result.api_key,
3915 })?;
3916
3917 Ok(())
3918}
3919
3920/// Start receiving chat updates for a channel
3921async fn join_channel_chat(
3922 request: proto::JoinChannelChat,
3923 response: Response<proto::JoinChannelChat>,
3924 session: Session,
3925) -> Result<()> {
3926 let channel_id = ChannelId::from_proto(request.channel_id);
3927
3928 let db = session.db().await;
3929 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3930 .await?;
3931 let messages = db
3932 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3933 .await?;
3934 response.send(proto::JoinChannelChatResponse {
3935 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3936 messages,
3937 })?;
3938 Ok(())
3939}
3940
3941/// Stop receiving chat updates for a channel
3942async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3943 let channel_id = ChannelId::from_proto(request.channel_id);
3944 session
3945 .db()
3946 .await
3947 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3948 .await?;
3949 Ok(())
3950}
3951
3952/// Retrieve the chat history for a channel
3953async fn get_channel_messages(
3954 request: proto::GetChannelMessages,
3955 response: Response<proto::GetChannelMessages>,
3956 session: Session,
3957) -> Result<()> {
3958 let channel_id = ChannelId::from_proto(request.channel_id);
3959 let messages = session
3960 .db()
3961 .await
3962 .get_channel_messages(
3963 channel_id,
3964 session.user_id(),
3965 MESSAGE_COUNT_PER_PAGE,
3966 Some(MessageId::from_proto(request.before_message_id)),
3967 )
3968 .await?;
3969 response.send(proto::GetChannelMessagesResponse {
3970 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3971 messages,
3972 })?;
3973 Ok(())
3974}
3975
3976/// Retrieve specific chat messages
3977async fn get_channel_messages_by_id(
3978 request: proto::GetChannelMessagesById,
3979 response: Response<proto::GetChannelMessagesById>,
3980 session: Session,
3981) -> Result<()> {
3982 let message_ids = request
3983 .message_ids
3984 .iter()
3985 .map(|id| MessageId::from_proto(*id))
3986 .collect::<Vec<_>>();
3987 let messages = session
3988 .db()
3989 .await
3990 .get_channel_messages_by_id(session.user_id(), &message_ids)
3991 .await?;
3992 response.send(proto::GetChannelMessagesResponse {
3993 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3994 messages,
3995 })?;
3996 Ok(())
3997}
3998
3999/// Retrieve the current users notifications
4000async fn get_notifications(
4001 request: proto::GetNotifications,
4002 response: Response<proto::GetNotifications>,
4003 session: Session,
4004) -> Result<()> {
4005 let notifications = session
4006 .db()
4007 .await
4008 .get_notifications(
4009 session.user_id(),
4010 NOTIFICATION_COUNT_PER_PAGE,
4011 request.before_id.map(db::NotificationId::from_proto),
4012 )
4013 .await?;
4014 response.send(proto::GetNotificationsResponse {
4015 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4016 notifications,
4017 })?;
4018 Ok(())
4019}
4020
4021/// Mark notifications as read
4022async fn mark_notification_as_read(
4023 request: proto::MarkNotificationRead,
4024 response: Response<proto::MarkNotificationRead>,
4025 session: Session,
4026) -> Result<()> {
4027 let database = &session.db().await;
4028 let notifications = database
4029 .mark_notification_as_read_by_id(
4030 session.user_id(),
4031 NotificationId::from_proto(request.notification_id),
4032 )
4033 .await?;
4034 send_notifications(
4035 &*session.connection_pool().await,
4036 &session.peer,
4037 notifications,
4038 );
4039 response.send(proto::Ack {})?;
4040 Ok(())
4041}
4042
4043/// Get the current users information
4044async fn get_private_user_info(
4045 _request: proto::GetPrivateUserInfo,
4046 response: Response<proto::GetPrivateUserInfo>,
4047 session: Session,
4048) -> Result<()> {
4049 let db = session.db().await;
4050
4051 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4052 let user = db
4053 .get_user_by_id(session.user_id())
4054 .await?
4055 .context("user not found")?;
4056 let flags = db.get_user_flags(session.user_id()).await?;
4057
4058 response.send(proto::GetPrivateUserInfoResponse {
4059 metrics_id,
4060 staff: user.admin,
4061 flags,
4062 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4063 })?;
4064 Ok(())
4065}
4066
4067/// Accept the terms of service (tos) on behalf of the current user
4068async fn accept_terms_of_service(
4069 _request: proto::AcceptTermsOfService,
4070 response: Response<proto::AcceptTermsOfService>,
4071 session: Session,
4072) -> Result<()> {
4073 let db = session.db().await;
4074
4075 let accepted_tos_at = Utc::now();
4076 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4077 .await?;
4078
4079 response.send(proto::AcceptTermsOfServiceResponse {
4080 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4081 })?;
4082 Ok(())
4083}
4084
/// Issue an LLM API token for the caller.
///
/// Steps, in order:
/// - Load the user's feature flags and the user record. The user must exist
///   and must have accepted the terms of service.
/// - Both the Stripe client and Stripe billing must be configured.
/// - If the database has no billing customer for the user, look up or create a
///   Stripe customer by the user's email, then look up or create the matching
///   billing customer.
/// - If the user has no active billing subscription, subscribe them to Zed Free
///   in Stripe and store a `SubscriptionKind::ZedFree` subscription row.
///
/// The token is then built by `LlmTokenClaims::create` from the user, staff
/// status, billing customer, billing preferences, flags, subscription, session
/// system id, and app config.
///
/// # Errors
/// Fails on any database or Stripe error, if the user is missing, if the terms
/// of service were not accepted, or if Stripe is not configured.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Refuse to issue a token before any billing work happens.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or create one via Stripe (keyed by email).
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are put on Zed Free.
    // NOTE(review): the lookup and the creation are separate steps. Concurrent
    // token requests could presumably both create a subscription. Confirm
    // whether the DB or Stripe layer guards against this.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4168
4169fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4170 let message = match message {
4171 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4172 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4173 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4174 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4175 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4176 code: frame.code.into(),
4177 reason: frame.reason.as_str().to_owned().into(),
4178 })),
4179 // We should never receive a frame while reading the message, according
4180 // to the `tungstenite` maintainers:
4181 //
4182 // > It cannot occur when you read messages from the WebSocket, but it
4183 // > can be used when you want to send the raw frames (e.g. you want to
4184 // > send the frames to the WebSocket without composing the full message first).
4185 // >
4186 // > — https://github.com/snapview/tungstenite-rs/issues/268
4187 TungsteniteMessage::Frame(_) => {
4188 bail!("received an unexpected frame while reading the message")
4189 }
4190 };
4191
4192 Ok(message)
4193}
4194
4195fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4196 match message {
4197 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4198 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4199 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4200 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4201 AxumMessage::Close(frame) => {
4202 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4203 code: frame.code.into(),
4204 reason: frame.reason.as_ref().into(),
4205 }))
4206 }
4207 }
4208}
4209
4210fn notify_membership_updated(
4211 connection_pool: &mut ConnectionPool,
4212 result: MembershipUpdated,
4213 user_id: UserId,
4214 peer: &Peer,
4215) {
4216 for membership in &result.new_channels.channel_memberships {
4217 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4218 }
4219 for channel_id in &result.removed_channels {
4220 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4221 }
4222
4223 let user_channels_update = proto::UpdateUserChannels {
4224 channel_memberships: result
4225 .new_channels
4226 .channel_memberships
4227 .iter()
4228 .map(|cm| proto::ChannelMembership {
4229 channel_id: cm.channel_id.to_proto(),
4230 role: cm.role.into(),
4231 })
4232 .collect(),
4233 ..Default::default()
4234 };
4235
4236 let mut update = build_channels_update(result.new_channels);
4237 update.delete_channels = result
4238 .removed_channels
4239 .into_iter()
4240 .map(|id| id.to_proto())
4241 .collect();
4242 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4243
4244 for connection_id in connection_pool.user_connection_ids(user_id) {
4245 peer.send(connection_id, user_channels_update.clone())
4246 .trace_err();
4247 peer.send(connection_id, update.clone()).trace_err();
4248 }
4249}
4250
4251fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4252 proto::UpdateUserChannels {
4253 channel_memberships: channels
4254 .channel_memberships
4255 .iter()
4256 .map(|m| proto::ChannelMembership {
4257 channel_id: m.channel_id.to_proto(),
4258 role: m.role.into(),
4259 })
4260 .collect(),
4261 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4262 observed_channel_message_id: channels.observed_channel_messages.clone(),
4263 }
4264}
4265
4266fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4267 let mut update = proto::UpdateChannels::default();
4268
4269 for channel in channels.channels {
4270 update.channels.push(channel.to_proto());
4271 }
4272
4273 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4274 update.latest_channel_message_ids = channels.latest_channel_messages;
4275
4276 for (channel_id, participants) in channels.channel_participants {
4277 update
4278 .channel_participants
4279 .push(proto::ChannelParticipants {
4280 channel_id: channel_id.to_proto(),
4281 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4282 });
4283 }
4284
4285 for channel in channels.invited_channels {
4286 update.channel_invitations.push(channel.to_proto());
4287 }
4288
4289 update
4290}
4291
4292fn build_initial_contacts_update(
4293 contacts: Vec<db::Contact>,
4294 pool: &ConnectionPool,
4295) -> proto::UpdateContacts {
4296 let mut update = proto::UpdateContacts::default();
4297
4298 for contact in contacts {
4299 match contact {
4300 db::Contact::Accepted { user_id, busy } => {
4301 update.contacts.push(contact_for_user(user_id, busy, pool));
4302 }
4303 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4304 db::Contact::Incoming { user_id } => {
4305 update
4306 .incoming_requests
4307 .push(proto::IncomingContactRequest {
4308 requester_id: user_id.to_proto(),
4309 })
4310 }
4311 }
4312 }
4313
4314 update
4315}
4316
4317fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4318 proto::Contact {
4319 user_id: user_id.to_proto(),
4320 online: pool.is_user_online(user_id),
4321 busy,
4322 }
4323}
4324
4325fn room_updated(room: &proto::Room, peer: &Peer) {
4326 broadcast(
4327 None,
4328 room.participants
4329 .iter()
4330 .filter_map(|participant| Some(participant.peer_id?.into())),
4331 |peer_id| {
4332 peer.send(
4333 peer_id,
4334 proto::RoomUpdated {
4335 room: Some(room.clone()),
4336 },
4337 )
4338 },
4339 );
4340}
4341
4342fn channel_updated(
4343 channel: &db::channel::Model,
4344 room: &proto::Room,
4345 peer: &Peer,
4346 pool: &ConnectionPool,
4347) {
4348 let participants = room
4349 .participants
4350 .iter()
4351 .map(|p| p.user_id)
4352 .collect::<Vec<_>>();
4353
4354 broadcast(
4355 None,
4356 pool.channel_connection_ids(channel.root_id())
4357 .filter_map(|(channel_id, role)| {
4358 role.can_see_channel(channel.visibility)
4359 .then_some(channel_id)
4360 }),
4361 |peer_id| {
4362 peer.send(
4363 peer_id,
4364 proto::UpdateChannels {
4365 channel_participants: vec![proto::ChannelParticipants {
4366 channel_id: channel.id.to_proto(),
4367 participant_user_ids: participants.clone(),
4368 }],
4369 ..Default::default()
4370 },
4371 )
4372 },
4373 );
4374}
4375
4376async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4377 let db = session.db().await;
4378
4379 let contacts = db.get_contacts(user_id).await?;
4380 let busy = db.is_user_busy(user_id).await?;
4381
4382 let pool = session.connection_pool().await;
4383 let updated_contact = contact_for_user(user_id, busy, &pool);
4384 for contact in contacts {
4385 if let db::Contact::Accepted {
4386 user_id: contact_user_id,
4387 ..
4388 } = contact
4389 {
4390 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4391 session
4392 .peer
4393 .send(
4394 contact_conn_id,
4395 proto::UpdateContacts {
4396 contacts: vec![updated_contact.clone()],
4397 remove_contacts: Default::default(),
4398 incoming_requests: Default::default(),
4399 remove_incoming_requests: Default::default(),
4400 outgoing_requests: Default::default(),
4401 remove_outgoing_requests: Default::default(),
4402 },
4403 )
4404 .trace_err();
4405 }
4406 }
4407 }
4408 Ok(())
4409}
4410
/// Handles `connection_id` leaving its room.
///
/// Calls `leave_room` for the connection on the session's database handle. If
/// that returns `None`, this returns `Ok(())` without doing anything else.
/// Otherwise it:
/// - notifies the connections of each project left, via `project_left`;
/// - broadcasts the updated room to its participants and, if the room has an
///   associated channel, the channel's updated participant list;
/// - sends `proto::CallCanceled` to every connection of each user whose call
///   was canceled;
/// - refreshes the contact entries of the leaving user and of every user whose
///   call was canceled;
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room, and deletes that room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`. Send
/// and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the `Some` branch
    // below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `if let` keeps both the `session.db().await` temporary
    // and `left_room` scoped to this statement. They are therefore dropped
    // before `update_user_contacts` (which calls `session.db()` itself) runs
    // below. Presumably this matters for locking — confirm before restructuring
    // this into a `let ... else`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room
        // itself is taken. The `room` broadcast below therefore has its
        // `livekit_room` field reset to the default value.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // The connection-pool handle is a temporary that lives only for this call.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool handle is dropped before `update_user_contacts`, which
    // calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): because of `?`, a failure while updating any contact
    // returns before the LiveKit cleanup below runs.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4484
4485async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4486 let left_channel_buffers = session
4487 .db()
4488 .await
4489 .leave_channel_buffers(session.connection_id)
4490 .await?;
4491
4492 for left_buffer in left_channel_buffers {
4493 channel_buffer_updated(
4494 session.connection_id,
4495 left_buffer.connections,
4496 &proto::UpdateChannelBufferCollaborators {
4497 channel_id: left_buffer.channel_id.to_proto(),
4498 collaborators: left_buffer.collaborators,
4499 },
4500 &session.peer,
4501 );
4502 }
4503
4504 Ok(())
4505}
4506
4507fn project_left(project: &db::LeftProject, session: &Session) {
4508 for connection_id in &project.connection_ids {
4509 if project.should_unshare {
4510 session
4511 .peer
4512 .send(
4513 *connection_id,
4514 proto::UnshareProject {
4515 project_id: project.id.to_proto(),
4516 },
4517 )
4518 .trace_err();
4519 } else {
4520 session
4521 .peer
4522 .send(
4523 *connection_id,
4524 proto::RemoveProjectCollaborator {
4525 project_id: project.id.to_proto(),
4526 peer_id: Some(session.connection_id.into()),
4527 },
4528 )
4529 .trace_err();
4530 }
4531 }
4532}
4533
4534pub trait ResultExt {
4535 type Ok;
4536
4537 fn trace_err(self) -> Option<Self::Ok>;
4538}
4539
4540impl<T, E> ResultExt for Result<T, E>
4541where
4542 E: std::fmt::Debug,
4543{
4544 type Ok = T;
4545
4546 #[track_caller]
4547 fn trace_err(self) -> Option<T> {
4548 match self {
4549 Ok(value) => Some(value),
4550 Err(error) => {
4551 tracing::error!("{:?}", error);
4552 None
4553 }
4554 }
4555 }
4556}