1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use reqwest_client::ReqwestClient;
45use rpc::proto::{MultiLspQuery, split_repository_update};
46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
// NOTE(review): consumed elsewhere in the crate; presumably the window a
// client has to reconnect before its server-side state is discarded — confirm
// at call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes and length limits for chat/notification traffic.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Process-wide count of slots held by live `ConnectionGuard`s.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// payload's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
99
100pub struct ConnectionGuard;
101
102impl ConnectionGuard {
103 pub fn try_acquire() -> Result<Self, ()> {
104 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
105 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
106 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
107 tracing::error!(
108 "too many concurrent connections: {}",
109 current_connections + 1
110 );
111 return Err(());
112 }
113 Ok(ConnectionGuard)
114 }
115}
116
117impl Drop for ConnectionGuard {
118 fn drop(&mut self) {
119 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
120 }
121}
122
/// Handle a request handler uses to reply to one specific RPC request.
///
/// Tracks whether a reply was actually sent so `add_request_handler` can
/// detect handlers that return `Ok` without responding.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with `add_request_handler`, which inspects it after the handler
    // future completes.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the response to the originating request.
    fn send(self, payload: R::Response) -> Result<()> {
        // The flag is set before sending and stays set even if the send below
        // fails, so a failed send is not reported as "no response sent".
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
136
137#[derive(Clone, Debug)]
138pub enum Principal {
139 User(User),
140 Impersonated { user: User, admin: User },
141}
142
143impl Principal {
144 fn user(&self) -> &User {
145 match self {
146 Principal::User(user) => user,
147 Principal::Impersonated { user, .. } => user,
148 }
149 }
150
151 fn update_span(&self, span: &tracing::Span) {
152 match &self {
153 Principal::User(user) => {
154 span.record("user_id", user.id.0);
155 span.record("login", &user.github_login);
156 }
157 Principal::Impersonated { user, admin } => {
158 span.record("user_id", user.id.0);
159 span.record("login", &user.github_login);
160 span.record("impersonator", &admin.github_login);
161 }
162 }
163 }
164}
165
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Async mutex: serializes database access for this session and may be
    // held across `.await` points (see `Session::db`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
181
182impl Session {
183 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
184 #[cfg(test)]
185 tokio::task::yield_now().await;
186 let guard = self.db.lock().await;
187 #[cfg(test)]
188 tokio::task::yield_now().await;
189 guard
190 }
191
192 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
193 #[cfg(test)]
194 tokio::task::yield_now().await;
195 let guard = self.connection_pool.lock();
196 ConnectionPoolGuard {
197 guard,
198 _not_send: PhantomData,
199 }
200 }
201
202 fn is_staff(&self) -> bool {
203 match &self.principal {
204 Principal::User(user) => user.admin,
205 Principal::Impersonated { .. } => true,
206 }
207 }
208
209 fn user_id(&self) -> UserId {
210 match &self.principal {
211 Principal::User(user) => user.id,
212 Principal::Impersonated { user, .. } => user.id,
213 }
214 }
215
216 pub fn email(&self) -> Option<String> {
217 match &self.principal {
218 Principal::User(user) => user.email_address.clone(),
219 Principal::Impersonated { user, .. } => user.email_address.clone(),
220 }
221 }
222}
223
224impl Debug for Session {
225 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
226 let mut result = f.debug_struct("Session");
227 match &self.principal {
228 Principal::User(user) => {
229 result.field("user", &user.github_login);
230 }
231 Principal::Impersonated { user, admin } => {
232 result.field("user", &user.github_login);
233 result.field("impersonator", &admin.github_login);
234 }
235 }
236 result.field("connection_id", &self.connection_id).finish()
237 }
238}
239
240struct DbHandle(Arc<Database>);
241
242impl Deref for DbHandle {
243 type Target = Database;
244
245 fn deref(&self) -> &Self::Target {
246 self.0.as_ref()
247 }
248}
249
/// The collaboration RPC server: owns the peer, the connection pool, and the
/// table of message handlers registered in [`Server::new`].
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Type-erased handlers keyed by each message payload's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Flipped to `true` in `teardown` to stop in-flight connection tasks.
    teardown: watch::Sender<bool>,
}
258
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` makes the guard `!Send`, which prevents holding
/// the pool lock across an `.await` point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
263
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool through the guard's `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
270
271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
272where
273 S: Serializer,
274 T: Deref<Target = U>,
275 U: Serialize,
276{
277 Serialize::serialize(value.deref(), serializer)
278}
279
280impl Server {
    /// Creates a new `Server` and registers a handler for every RPC message
    /// type it services. Request handlers must reply (see
    /// `add_request_handler`); message handlers are fire-and-forget.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participants.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // LSP-style requests forwarded to the project host.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization between host and guests.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and miscellaneous account endpoints.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
456
    /// Spawns a detached background task that waits `CLEANUP_TIMEOUT` (so pods
    /// from a previous deployment can finish shutting down) and then cleans up
    /// stale resources those instances left behind: room participants, channel
    /// buffer collaborators, chat participants, worktree entries, and server
    /// rows.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // Errors below are logged-and-ignored (`trace_err`): cleanup is
                // best-effort and must not take the server down.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // notify the remaining collaborators of the new set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the stale server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled. The
                        // pool lock is scoped to this block so it is released
                        // before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed busy/contact state of each
                        // affected user to that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
627
628 pub fn teardown(&self) {
629 self.peer.teardown();
630 self.connection_pool.lock().reset();
631 let _ = self.teardown.send(true);
632 }
633
    /// Test-only: simulates a server restart by tearing everything down and
    /// re-arming the server under a new `ServerId`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Un-flip the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }

    /// Test-only: the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
646
    /// Registers `handler` for message type `M`, wrapping it with logging and
    /// queue/processing-time instrumentation.
    ///
    /// # Panics
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // The erased envelope was stored under `TypeId::of::<M>()`, so
                // this downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Fractional milliseconds: total time since the message
                    // arrived, time inside the handler, and the difference
                    // (time spent queued before the handler ran).
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
693
694 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
695 where
696 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
697 Fut: 'static + Send + Future<Output = Result<()>>,
698 M: EnvelopedMessage,
699 {
700 self.add_handler(move |envelope, session| handler(envelope.payload, session));
701 self
702 }
703
    /// Registers a request handler for `M` and enforces the request/response
    /// contract: when the handler errors, an error response is sent back to
    /// the requester; when it returns `Ok`, it must have called
    /// `Response::send`.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set to `true` by `Response::send`; checked below to catch
                // handlers that return `Ok` without ever responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // NOTE(review): in this branch no response of any
                            // kind is sent to the requester; presumably the
                            // client relies on its own request timeout —
                            // confirm before changing.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
742
    /// Drives one client connection end-to-end: registers it with the peer,
    /// builds the per-connection `Session`, sends the initial client update,
    /// then pumps incoming messages until the connection closes, I/O fails,
    /// or the server tears down.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Dropping the guard decrements `CONCURRENT_CONNECTIONS` (see
            // `ConnectionGuard`), so the cap applies to connections still
            // completing setup rather than to all established connections.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler permit is free,
                // bounding the number of in-flight handlers per connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler future
                                // completes, keeping the semaphore accurate.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
924
    /// Sends the initial batch of state to a freshly-authenticated connection:
    /// the `Hello` handshake, then (for user principals) contacts, plan info,
    /// channel subscriptions, and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller awaiting it, if any.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's very first connection, prompt the client to show
                // the contacts panel, then persist that we've done so.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact list in
                // one critical section, so the snapshot of online state is
                // consistent with the pool at the moment of registration.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Replay any call that was already ringing for this user.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
982
983 pub async fn invite_code_redeemed(
984 self: &Arc<Self>,
985 inviter_id: UserId,
986 invitee_id: UserId,
987 ) -> Result<()> {
988 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
989 if let Some(code) = &user.invite_code {
990 let pool = self.connection_pool.lock();
991 let invitee_contact = contact_for_user(invitee_id, false, &pool);
992 for connection_id in pool.user_connection_ids(inviter_id) {
993 self.peer.send(
994 connection_id,
995 proto::UpdateContacts {
996 contacts: vec![invitee_contact.clone()],
997 ..Default::default()
998 },
999 )?;
1000 self.peer.send(
1001 connection_id,
1002 proto::UpdateInviteInfo {
1003 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1004 count: user.invite_count as u32,
1005 },
1006 )?;
1007 }
1008 }
1009 }
1010 Ok(())
1011 }
1012
1013 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1014 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1015 if let Some(invite_code) = &user.invite_code {
1016 let pool = self.connection_pool.lock();
1017 for connection_id in pool.user_connection_ids(user_id) {
1018 self.peer.send(
1019 connection_id,
1020 proto::UpdateInviteInfo {
1021 url: format!(
1022 "{}{}",
1023 self.app_state.config.invite_link_prefix, invite_code
1024 ),
1025 count: user.invite_count as u32,
1026 },
1027 )?;
1028 }
1029 }
1030 }
1031 Ok(())
1032 }
1033
1034 pub async fn update_plan_for_user(
1035 self: &Arc<Self>,
1036 user_id: UserId,
1037 update_user_plan: proto::UpdateUserPlan,
1038 ) -> Result<()> {
1039 let pool = self.connection_pool.lock();
1040 for connection_id in pool.user_connection_ids(user_id) {
1041 self.peer
1042 .send(connection_id, update_user_plan.clone())
1043 .trace_err();
1044 }
1045
1046 Ok(())
1047 }
1048
1049 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1050 /// message on the Collab server.
1051 ///
1052 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1053 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1054 let user = self
1055 .app_state
1056 .db
1057 .get_user_by_id(user_id)
1058 .await?
1059 .context("user not found")?;
1060
1061 let update_user_plan = make_update_user_plan_message(
1062 &user,
1063 user.admin,
1064 &self.app_state.db,
1065 self.app_state.llm_db.clone(),
1066 )
1067 .await?;
1068
1069 self.update_plan_for_user(user_id, update_user_plan).await
1070 }
1071
1072 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1073 let pool = self.connection_pool.lock();
1074 for connection_id in pool.user_connection_ids(user_id) {
1075 self.peer
1076 .send(connection_id, proto::RefreshLlmToken {})
1077 .trace_err();
1078 }
1079 }
1080
    /// Captures a view of the server's current connections and peer.
    /// The returned snapshot holds the connection-pool lock (via
    /// `ConnectionPoolGuard`) for as long as it is alive.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker keeping the guard (and thus the snapshot) !Send.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1090}
1091
// Read access to the underlying pool through the guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1099
// Mutable access to the underlying pool through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1105
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants each time a guard is released.
        #[cfg(test)]
        self.check_invariants();
    }
}
1112
1113fn broadcast<F>(
1114 sender_id: Option<ConnectionId>,
1115 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1116 mut f: F,
1117) where
1118 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1119{
1120 for receiver_id in receiver_ids {
1121 if Some(receiver_id) != sender_id {
1122 if let Err(error) = f(receiver_id) {
1123 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1124 }
1125 }
1126 }
1127}
1128
1129pub struct ProtocolVersion(u32);
1130
1131impl Header for ProtocolVersion {
1132 fn name() -> &'static HeaderName {
1133 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1135 }
1136
1137 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138 where
1139 Self: Sized,
1140 I: Iterator<Item = &'i axum::http::HeaderValue>,
1141 {
1142 let version = values
1143 .next()
1144 .ok_or_else(axum::headers::Error::invalid)?
1145 .to_str()
1146 .map_err(|_| axum::headers::Error::invalid())?
1147 .parse()
1148 .map_err(|_| axum::headers::Error::invalid())?;
1149 Ok(Self(version))
1150 }
1151
1152 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153 values.extend([self.0.to_string().parse().unwrap()]);
1154 }
1155}
1156
1157pub struct AppVersionHeader(SemanticVersion);
1158impl Header for AppVersionHeader {
1159 fn name() -> &'static HeaderName {
1160 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1161 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1162 }
1163
1164 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1165 where
1166 Self: Sized,
1167 I: Iterator<Item = &'i axum::http::HeaderValue>,
1168 {
1169 let version = values
1170 .next()
1171 .ok_or_else(axum::headers::Error::invalid)?
1172 .to_str()
1173 .map_err(|_| axum::headers::Error::invalid())?
1174 .parse()
1175 .map_err(|_| axum::headers::Error::invalid())?;
1176 Ok(Self(version))
1177 }
1178
1179 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1180 values.extend([self.0.to_string().parse().unwrap()]);
1181 }
1182}
1183
/// Builds the router for this service:
/// - `/rpc`: the WebSocket endpoint, guarded by header-based auth.
/// - `/metrics`: Prometheus metrics, registered after the auth layer so it
///   is reachable without credentials.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // In axum, `.layer` applies only to routes added above it, so this
        // auth middleware covers `/rpc` but not `/metrics`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // The server extension is available to every route, including `/metrics`.
        .layer(Extension(server))
}
1195
/// Upgrades an authenticated HTTP request into the collab WebSocket
/// connection.
///
/// Rejects clients whose RPC protocol version or app version is missing or
/// incompatible with `426 Upgrade Required`, and sheds load with
/// `503 Service Unavailable` when the concurrent-connection limit is reached.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The wire protocol must match exactly; there is no cross-version support.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so we can still
    // return an HTTP error response when the server is at capacity.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket into the tungstenite message types that
        // `rpc::Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1269
/// Serves Prometheus metrics in text exposition format, refreshing the
/// connection and shared-project gauges on each scrape.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    // Gauge of live connections, excluding admin users.
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Encode the entire global registry, not just the gauges updated above.
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1302
/// Cleans up after a dropped connection: disconnects the peer, removes it
/// from the pool, and — unless the client reconnects within
/// `RECONNECT_TIMEOUT` or the server tears down first — releases the
/// connection's rooms, channel buffers, and pending calls.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediately drop the peer and pool entries for this connection.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Mark the connection lost in the database; best-effort.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period elapsed without a reconnect: release all resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls if the user has no other live connections.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the per-connection cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1349
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    // No state to update; just acknowledge receipt.
    response.send(proto::Ack {})?;
    Ok(())
}
1355
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    // Random identifier for the backing LiveKit room.
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: mint a LiveKit token if a client is configured; failures
    // are logged and the room is still created without audio/video info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit
            .room_token(&livekit_room, &user_id.to_string())
            .trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // Contacts see the caller's busy state change.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1395
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel go through the channel-join path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The user has answered on this connection; cancel the ringing call on
    // all of their other connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; joining still succeeds without one.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1462
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database room guard: everything that needs the locked room
    // state happens inside this block, and only plain data escapes it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this user was hosting, tell its guests about the
        // host's new peer id and re-send the current worktree metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1554
/// After a reconnect, notifies collaborators of the rejoining peer's new id
/// and streams the accumulated project state (worktree entries, diagnostics,
/// settings, repositories) back to the rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell existing collaborators that this user's peer id changed.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their payloads aren't cloned.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are chunked to stay under message-size limits.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository updates are chunked like worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1639
/// leave room disconnects from the room.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: Session,
) -> Result<()> {
    // All teardown (room state, projects, contacts) happens in the shared helper.
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1650
1651/// Updates the permissions of someone else in the room.
1652async fn set_room_participant_role(
1653 request: proto::SetRoomParticipantRole,
1654 response: Response<proto::SetRoomParticipantRole>,
1655 session: Session,
1656) -> Result<()> {
1657 let user_id = UserId::from_proto(request.user_id);
1658 let role = ChannelRole::from(request.role());
1659
1660 let (livekit_room, can_publish) = {
1661 let room = session
1662 .db()
1663 .await
1664 .set_room_participant_role(
1665 session.user_id(),
1666 RoomId::from_proto(request.room_id),
1667 user_id,
1668 role,
1669 )
1670 .await?;
1671
1672 let livekit_room = room.livekit_room.clone();
1673 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1674 room_updated(&room, &session.peer);
1675 (livekit_room, can_publish)
1676 };
1677
1678 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1679 live_kit
1680 .update_participant(
1681 livekit_room.clone(),
1682 request.user_id.to_string(),
1683 livekit_api::proto::ParticipantPermission {
1684 can_subscribe: true,
1685 can_publish,
1686 can_publish_data: can_publish,
1687 hidden: false,
1688 recorder: false,
1689 },
1690 )
1691 .await
1692 .trace_err();
1693 }
1694
1695 response.send(proto::Ack {})?;
1696 Ok(())
1697}
1698
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call in the room and take the ring message to send.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first one that
    // acknowledges wins and completes the call.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Individual connection failures are logged; keep ringing the rest.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll back the pending call state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1767
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Remove the pending call from the room and notify its participants.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Stop the ringing on every connection of the callee; best-effort per connection.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1805
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Record the decline and notify the room's participants.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Stop the ringing on all of the declining user's own connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1837
/// Updates other participants in the room with your current location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request.location.context("invalid location")?;

    // Persist the location; the returned room state is broadcast to participants.
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1856
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // Register the project (and its worktrees) against the room in the database.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    // Reply with the newly-assigned project id, then notify room participants.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1880
/// Unshare a project from the room.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);
    // Delegate to the shared implementation used by other teardown paths.
    unshare_project_internal(project_id, session.connection_id, &session).await
}
1886
/// Removes a shared project: notifies guests, updates the room, and — when
/// the database marks the project for deletion — deletes it.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scope the DB room guard; only the `delete` flag escapes the block.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharing connection) the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Perform the deletion after the room guard has been released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1924
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    // Register the guest in the database and obtain the project snapshot
    // plus the replica id assigned to this connection.
    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the DB guard before the streaming work below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Every other collaborator on the project (host and guests), excluding us.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Announce the new guest to existing collaborators; best-effort per peer.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are chunked to stay under message-size limits.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state, chunked like worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Finally, surface each language server's disk-based diagnostics state.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2074
2075/// Leave someone elses shared project.
2076async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2077 let sender_id = session.connection_id;
2078 let project_id = ProjectId::from_proto(request.project_id);
2079 let db = session.db().await;
2080
2081 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2082 tracing::info!(
2083 %project_id,
2084 "leave project"
2085 );
2086
2087 project_left(project, &session);
2088 if let Some(room) = room {
2089 room_updated(room, &session.peer);
2090 }
2091
2092 Ok(())
2093}
2094
2095/// Updates other participants with changes to the project
2096async fn update_project(
2097 request: proto::UpdateProject,
2098 response: Response<proto::UpdateProject>,
2099 session: Session,
2100) -> Result<()> {
2101 let project_id = ProjectId::from_proto(request.project_id);
2102 let (room, guest_connection_ids) = &*session
2103 .db()
2104 .await
2105 .update_project(project_id, session.connection_id, &request.worktrees)
2106 .await?;
2107 broadcast(
2108 Some(session.connection_id),
2109 guest_connection_ids.iter().copied(),
2110 |connection_id| {
2111 session
2112 .peer
2113 .forward_send(session.connection_id, connection_id, request.clone())
2114 },
2115 );
2116 if let Some(room) = room {
2117 room_updated(room, &session.peer);
2118 }
2119 response.send(proto::Ack {})?;
2120
2121 Ok(())
2122}
2123
2124/// Updates other participants with changes to the worktree
2125async fn update_worktree(
2126 request: proto::UpdateWorktree,
2127 response: Response<proto::UpdateWorktree>,
2128 session: Session,
2129) -> Result<()> {
2130 let guest_connection_ids = session
2131 .db()
2132 .await
2133 .update_worktree(&request, session.connection_id)
2134 .await?;
2135
2136 broadcast(
2137 Some(session.connection_id),
2138 guest_connection_ids.iter().copied(),
2139 |connection_id| {
2140 session
2141 .peer
2142 .forward_send(session.connection_id, connection_id, request.clone())
2143 },
2144 );
2145 response.send(proto::Ack {})?;
2146 Ok(())
2147}
2148
2149async fn update_repository(
2150 request: proto::UpdateRepository,
2151 response: Response<proto::UpdateRepository>,
2152 session: Session,
2153) -> Result<()> {
2154 let guest_connection_ids = session
2155 .db()
2156 .await
2157 .update_repository(&request, session.connection_id)
2158 .await?;
2159
2160 broadcast(
2161 Some(session.connection_id),
2162 guest_connection_ids.iter().copied(),
2163 |connection_id| {
2164 session
2165 .peer
2166 .forward_send(session.connection_id, connection_id, request.clone())
2167 },
2168 );
2169 response.send(proto::Ack {})?;
2170 Ok(())
2171}
2172
2173async fn remove_repository(
2174 request: proto::RemoveRepository,
2175 response: Response<proto::RemoveRepository>,
2176 session: Session,
2177) -> Result<()> {
2178 let guest_connection_ids = session
2179 .db()
2180 .await
2181 .remove_repository(&request, session.connection_id)
2182 .await?;
2183
2184 broadcast(
2185 Some(session.connection_id),
2186 guest_connection_ids.iter().copied(),
2187 |connection_id| {
2188 session
2189 .peer
2190 .forward_send(session.connection_id, connection_id, request.clone())
2191 },
2192 );
2193 response.send(proto::Ack {})?;
2194 Ok(())
2195}
2196
2197/// Updates other participants with changes to the diagnostics
2198async fn update_diagnostic_summary(
2199 message: proto::UpdateDiagnosticSummary,
2200 session: Session,
2201) -> Result<()> {
2202 let guest_connection_ids = session
2203 .db()
2204 .await
2205 .update_diagnostic_summary(&message, session.connection_id)
2206 .await?;
2207
2208 broadcast(
2209 Some(session.connection_id),
2210 guest_connection_ids.iter().copied(),
2211 |connection_id| {
2212 session
2213 .peer
2214 .forward_send(session.connection_id, connection_id, message.clone())
2215 },
2216 );
2217
2218 Ok(())
2219}
2220
2221/// Updates other participants with changes to the worktree settings
2222async fn update_worktree_settings(
2223 message: proto::UpdateWorktreeSettings,
2224 session: Session,
2225) -> Result<()> {
2226 let guest_connection_ids = session
2227 .db()
2228 .await
2229 .update_worktree_settings(&message, session.connection_id)
2230 .await?;
2231
2232 broadcast(
2233 Some(session.connection_id),
2234 guest_connection_ids.iter().copied(),
2235 |connection_id| {
2236 session
2237 .peer
2238 .forward_send(session.connection_id, connection_id, message.clone())
2239 },
2240 );
2241
2242 Ok(())
2243}
2244
2245/// Notify other participants that a language server has started.
2246async fn start_language_server(
2247 request: proto::StartLanguageServer,
2248 session: Session,
2249) -> Result<()> {
2250 let guest_connection_ids = session
2251 .db()
2252 .await
2253 .start_language_server(&request, session.connection_id)
2254 .await?;
2255
2256 broadcast(
2257 Some(session.connection_id),
2258 guest_connection_ids.iter().copied(),
2259 |connection_id| {
2260 session
2261 .peer
2262 .forward_send(session.connection_id, connection_id, request.clone())
2263 },
2264 );
2265 Ok(())
2266}
2267
2268/// Notify other participants that a language server has changed.
2269async fn update_language_server(
2270 request: proto::UpdateLanguageServer,
2271 session: Session,
2272) -> Result<()> {
2273 let project_id = ProjectId::from_proto(request.project_id);
2274 let db = session.db().await;
2275
2276 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2277 {
2278 if let Some(capabilities) = update.capabilities.clone() {
2279 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2280 .await?;
2281 }
2282 }
2283
2284 let project_connection_ids = db
2285 .project_connection_ids(project_id, session.connection_id, true)
2286 .await?;
2287 broadcast(
2288 Some(session.connection_id),
2289 project_connection_ids.iter().copied(),
2290 |connection_id| {
2291 session
2292 .peer
2293 .forward_send(session.connection_id, connection_id, request.clone())
2294 },
2295 );
2296 Ok(())
2297}
2298
2299/// forward a project request to the host. These requests should be read only
2300/// as guests are allowed to send them.
2301async fn forward_read_only_project_request<T>(
2302 request: T,
2303 response: Response<T>,
2304 session: Session,
2305) -> Result<()>
2306where
2307 T: EntityMessage + RequestMessage,
2308{
2309 let project_id = ProjectId::from_proto(request.remote_entity_id());
2310 let host_connection_id = session
2311 .db()
2312 .await
2313 .host_for_read_only_project_request(project_id, session.connection_id)
2314 .await?;
2315 let payload = session
2316 .peer
2317 .forward_request(session.connection_id, host_connection_id, request)
2318 .await?;
2319 response.send(payload)?;
2320 Ok(())
2321}
2322
2323/// forward a project request to the host. These requests are disallowed
2324/// for guests.
2325async fn forward_mutating_project_request<T>(
2326 request: T,
2327 response: Response<T>,
2328 session: Session,
2329) -> Result<()>
2330where
2331 T: EntityMessage + RequestMessage,
2332{
2333 let project_id = ProjectId::from_proto(request.remote_entity_id());
2334
2335 let host_connection_id = session
2336 .db()
2337 .await
2338 .host_for_mutating_project_request(project_id, session.connection_id)
2339 .await?;
2340 let payload = session
2341 .peer
2342 .forward_request(session.connection_id, host_connection_id, request)
2343 .await?;
2344 response.send(payload)?;
2345 Ok(())
2346}
2347
2348async fn multi_lsp_query(
2349 request: MultiLspQuery,
2350 response: Response<MultiLspQuery>,
2351 session: Session,
2352) -> Result<()> {
2353 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2354 tracing::info!("multi_lsp_query message received");
2355 forward_mutating_project_request(request, response, session).await
2356}
2357
2358/// Notify other participants that a new buffer has been created
2359async fn create_buffer_for_peer(
2360 request: proto::CreateBufferForPeer,
2361 session: Session,
2362) -> Result<()> {
2363 session
2364 .db()
2365 .await
2366 .check_user_is_project_host(
2367 ProjectId::from_proto(request.project_id),
2368 session.connection_id,
2369 )
2370 .await?;
2371 let peer_id = request.peer_id.context("invalid peer id")?;
2372 session
2373 .peer
2374 .forward_send(session.connection_id, peer_id.into(), request)?;
2375 Ok(())
2376}
2377
2378/// Notify other participants that a buffer has been updated. This is
2379/// allowed for guests as long as the update is limited to selections.
2380async fn update_buffer(
2381 request: proto::UpdateBuffer,
2382 response: Response<proto::UpdateBuffer>,
2383 session: Session,
2384) -> Result<()> {
2385 let project_id = ProjectId::from_proto(request.project_id);
2386 let mut capability = Capability::ReadOnly;
2387
2388 for op in request.operations.iter() {
2389 match op.variant {
2390 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2391 Some(_) => capability = Capability::ReadWrite,
2392 }
2393 }
2394
2395 let host = {
2396 let guard = session
2397 .db()
2398 .await
2399 .connections_for_buffer_update(project_id, session.connection_id, capability)
2400 .await?;
2401
2402 let (host, guests) = &*guard;
2403
2404 broadcast(
2405 Some(session.connection_id),
2406 guests.clone(),
2407 |connection_id| {
2408 session
2409 .peer
2410 .forward_send(session.connection_id, connection_id, request.clone())
2411 },
2412 );
2413
2414 *host
2415 };
2416
2417 if host != session.connection_id {
2418 session
2419 .peer
2420 .forward_request(session.connection_id, host, request.clone())
2421 .await?;
2422 }
2423
2424 response.send(proto::Ack {})?;
2425 Ok(())
2426}
2427
2428async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2429 let project_id = ProjectId::from_proto(message.project_id);
2430
2431 let operation = message.operation.as_ref().context("invalid operation")?;
2432 let capability = match operation.variant.as_ref() {
2433 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2434 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2435 match buffer_op.variant {
2436 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2437 Capability::ReadOnly
2438 }
2439 _ => Capability::ReadWrite,
2440 }
2441 } else {
2442 Capability::ReadWrite
2443 }
2444 }
2445 Some(_) => Capability::ReadWrite,
2446 None => Capability::ReadOnly,
2447 };
2448
2449 let guard = session
2450 .db()
2451 .await
2452 .connections_for_buffer_update(project_id, session.connection_id, capability)
2453 .await?;
2454
2455 let (host, guests) = &*guard;
2456
2457 broadcast(
2458 Some(session.connection_id),
2459 guests.iter().chain([host]).copied(),
2460 |connection_id| {
2461 session
2462 .peer
2463 .forward_send(session.connection_id, connection_id, message.clone())
2464 },
2465 );
2466
2467 Ok(())
2468}
2469
2470/// Notify other participants that a project has been updated.
2471async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2472 request: T,
2473 session: Session,
2474) -> Result<()> {
2475 let project_id = ProjectId::from_proto(request.remote_entity_id());
2476 let project_connection_ids = session
2477 .db()
2478 .await
2479 .project_connection_ids(project_id, session.connection_id, false)
2480 .await?;
2481
2482 broadcast(
2483 Some(session.connection_id),
2484 project_connection_ids.iter().copied(),
2485 |connection_id| {
2486 session
2487 .peer
2488 .forward_send(session.connection_id, connection_id, request.clone())
2489 },
2490 );
2491 Ok(())
2492}
2493
2494/// Start following another user in a call.
2495async fn follow(
2496 request: proto::Follow,
2497 response: Response<proto::Follow>,
2498 session: Session,
2499) -> Result<()> {
2500 let room_id = RoomId::from_proto(request.room_id);
2501 let project_id = request.project_id.map(ProjectId::from_proto);
2502 let leader_id = request.leader_id.context("invalid leader id")?.into();
2503 let follower_id = session.connection_id;
2504
2505 session
2506 .db()
2507 .await
2508 .check_room_participants(room_id, leader_id, session.connection_id)
2509 .await?;
2510
2511 let response_payload = session
2512 .peer
2513 .forward_request(session.connection_id, leader_id, request)
2514 .await?;
2515 response.send(response_payload)?;
2516
2517 if let Some(project_id) = project_id {
2518 let room = session
2519 .db()
2520 .await
2521 .follow(room_id, project_id, leader_id, follower_id)
2522 .await?;
2523 room_updated(&room, &session.peer);
2524 }
2525
2526 Ok(())
2527}
2528
2529/// Stop following another user in a call.
2530async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2531 let room_id = RoomId::from_proto(request.room_id);
2532 let project_id = request.project_id.map(ProjectId::from_proto);
2533 let leader_id = request.leader_id.context("invalid leader id")?.into();
2534 let follower_id = session.connection_id;
2535
2536 session
2537 .db()
2538 .await
2539 .check_room_participants(room_id, leader_id, session.connection_id)
2540 .await?;
2541
2542 session
2543 .peer
2544 .forward_send(session.connection_id, leader_id, request)?;
2545
2546 if let Some(project_id) = project_id {
2547 let room = session
2548 .db()
2549 .await
2550 .unfollow(room_id, project_id, leader_id, follower_id)
2551 .await?;
2552 room_updated(&room, &session.peer);
2553 }
2554
2555 Ok(())
2556}
2557
2558/// Notify everyone following you of your current location.
2559async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2560 let room_id = RoomId::from_proto(request.room_id);
2561 let database = session.db.lock().await;
2562
2563 let connection_ids = if let Some(project_id) = request.project_id {
2564 let project_id = ProjectId::from_proto(project_id);
2565 database
2566 .project_connection_ids(project_id, session.connection_id, true)
2567 .await?
2568 } else {
2569 database
2570 .room_connection_ids(room_id, session.connection_id)
2571 .await?
2572 };
2573
2574 // For now, don't send view update messages back to that view's current leader.
2575 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2576 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2577 _ => None,
2578 });
2579
2580 for connection_id in connection_ids.iter().cloned() {
2581 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2582 session
2583 .peer
2584 .forward_send(session.connection_id, connection_id, request.clone())?;
2585 }
2586 }
2587 Ok(())
2588}
2589
2590/// Get public data about users.
2591async fn get_users(
2592 request: proto::GetUsers,
2593 response: Response<proto::GetUsers>,
2594 session: Session,
2595) -> Result<()> {
2596 let user_ids = request
2597 .user_ids
2598 .into_iter()
2599 .map(UserId::from_proto)
2600 .collect();
2601 let users = session
2602 .db()
2603 .await
2604 .get_users_by_ids(user_ids)
2605 .await?
2606 .into_iter()
2607 .map(|user| proto::User {
2608 id: user.id.to_proto(),
2609 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2610 github_login: user.github_login,
2611 name: user.name,
2612 })
2613 .collect();
2614 response.send(proto::UsersResponse { users })?;
2615 Ok(())
2616}
2617
2618/// Search for users (to invite) buy Github login
2619async fn fuzzy_search_users(
2620 request: proto::FuzzySearchUsers,
2621 response: Response<proto::FuzzySearchUsers>,
2622 session: Session,
2623) -> Result<()> {
2624 let query = request.query;
2625 let users = match query.len() {
2626 0 => vec![],
2627 1 | 2 => session
2628 .db()
2629 .await
2630 .get_user_by_github_login(&query)
2631 .await?
2632 .into_iter()
2633 .collect(),
2634 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2635 };
2636 let users = users
2637 .into_iter()
2638 .filter(|user| user.id != session.user_id())
2639 .map(|user| proto::User {
2640 id: user.id.to_proto(),
2641 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2642 github_login: user.github_login,
2643 name: user.name,
2644 })
2645 .collect();
2646 response.send(proto::UsersResponse { users })?;
2647 Ok(())
2648}
2649
2650/// Send a contact request to another user.
2651async fn request_contact(
2652 request: proto::RequestContact,
2653 response: Response<proto::RequestContact>,
2654 session: Session,
2655) -> Result<()> {
2656 let requester_id = session.user_id();
2657 let responder_id = UserId::from_proto(request.responder_id);
2658 if requester_id == responder_id {
2659 return Err(anyhow!("cannot add yourself as a contact"))?;
2660 }
2661
2662 let notifications = session
2663 .db()
2664 .await
2665 .send_contact_request(requester_id, responder_id)
2666 .await?;
2667
2668 // Update outgoing contact requests of requester
2669 let mut update = proto::UpdateContacts::default();
2670 update.outgoing_requests.push(responder_id.to_proto());
2671 for connection_id in session
2672 .connection_pool()
2673 .await
2674 .user_connection_ids(requester_id)
2675 {
2676 session.peer.send(connection_id, update.clone())?;
2677 }
2678
2679 // Update incoming contact requests of responder
2680 let mut update = proto::UpdateContacts::default();
2681 update
2682 .incoming_requests
2683 .push(proto::IncomingContactRequest {
2684 requester_id: requester_id.to_proto(),
2685 });
2686 let connection_pool = session.connection_pool().await;
2687 for connection_id in connection_pool.user_connection_ids(responder_id) {
2688 session.peer.send(connection_id, update.clone())?;
2689 }
2690
2691 send_notifications(&connection_pool, &session.peer, notifications);
2692
2693 response.send(proto::Ack {})?;
2694 Ok(())
2695}
2696
2697/// Accept or decline a contact request
2698async fn respond_to_contact_request(
2699 request: proto::RespondToContactRequest,
2700 response: Response<proto::RespondToContactRequest>,
2701 session: Session,
2702) -> Result<()> {
2703 let responder_id = session.user_id();
2704 let requester_id = UserId::from_proto(request.requester_id);
2705 let db = session.db().await;
2706 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2707 db.dismiss_contact_notification(responder_id, requester_id)
2708 .await?;
2709 } else {
2710 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2711
2712 let notifications = db
2713 .respond_to_contact_request(responder_id, requester_id, accept)
2714 .await?;
2715 let requester_busy = db.is_user_busy(requester_id).await?;
2716 let responder_busy = db.is_user_busy(responder_id).await?;
2717
2718 let pool = session.connection_pool().await;
2719 // Update responder with new contact
2720 let mut update = proto::UpdateContacts::default();
2721 if accept {
2722 update
2723 .contacts
2724 .push(contact_for_user(requester_id, requester_busy, &pool));
2725 }
2726 update
2727 .remove_incoming_requests
2728 .push(requester_id.to_proto());
2729 for connection_id in pool.user_connection_ids(responder_id) {
2730 session.peer.send(connection_id, update.clone())?;
2731 }
2732
2733 // Update requester with new contact
2734 let mut update = proto::UpdateContacts::default();
2735 if accept {
2736 update
2737 .contacts
2738 .push(contact_for_user(responder_id, responder_busy, &pool));
2739 }
2740 update
2741 .remove_outgoing_requests
2742 .push(responder_id.to_proto());
2743
2744 for connection_id in pool.user_connection_ids(requester_id) {
2745 session.peer.send(connection_id, update.clone())?;
2746 }
2747
2748 send_notifications(&pool, &session.peer, notifications);
2749 }
2750
2751 response.send(proto::Ack {})?;
2752 Ok(())
2753}
2754
2755/// Remove a contact.
2756async fn remove_contact(
2757 request: proto::RemoveContact,
2758 response: Response<proto::RemoveContact>,
2759 session: Session,
2760) -> Result<()> {
2761 let requester_id = session.user_id();
2762 let responder_id = UserId::from_proto(request.user_id);
2763 let db = session.db().await;
2764 let (contact_accepted, deleted_notification_id) =
2765 db.remove_contact(requester_id, responder_id).await?;
2766
2767 let pool = session.connection_pool().await;
2768 // Update outgoing contact requests of requester
2769 let mut update = proto::UpdateContacts::default();
2770 if contact_accepted {
2771 update.remove_contacts.push(responder_id.to_proto());
2772 } else {
2773 update
2774 .remove_outgoing_requests
2775 .push(responder_id.to_proto());
2776 }
2777 for connection_id in pool.user_connection_ids(requester_id) {
2778 session.peer.send(connection_id, update.clone())?;
2779 }
2780
2781 // Update incoming contact requests of responder
2782 let mut update = proto::UpdateContacts::default();
2783 if contact_accepted {
2784 update.remove_contacts.push(requester_id.to_proto());
2785 } else {
2786 update
2787 .remove_incoming_requests
2788 .push(requester_id.to_proto());
2789 }
2790 for connection_id in pool.user_connection_ids(responder_id) {
2791 session.peer.send(connection_id, update.clone())?;
2792 if let Some(notification_id) = deleted_notification_id {
2793 session.peer.send(
2794 connection_id,
2795 proto::DeleteNotification {
2796 notification_id: notification_id.to_proto(),
2797 },
2798 )?;
2799 }
2800 }
2801
2802 response.send(proto::Ack {})?;
2803 Ok(())
2804}
2805
/// Whether the server should subscribe this connection to its channels at
/// connect time instead of waiting for a `SubscribeToChannels` request.
/// Only clients with minor version below 139 are auto-subscribed.
/// NOTE(review): presumably 0.139 is the client release that started sending
/// `SubscribeToChannels` itself — confirm against the client changelog.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2809
2810async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2811 if is_staff {
2812 return Ok(proto::Plan::ZedPro);
2813 }
2814
2815 let subscription = db.get_active_billing_subscription(user_id).await?;
2816 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2817
2818 let plan = if let Some(subscription_kind) = subscription_kind {
2819 match subscription_kind {
2820 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2821 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2822 SubscriptionKind::ZedFree => proto::Plan::Free,
2823 }
2824 } else {
2825 proto::Plan::Free
2826 };
2827
2828 Ok(plan)
2829}
2830
/// Assembles the `UpdateUserPlan` message for `user`: plan, trial and
/// subscription-period metadata, billing preferences, and model-request /
/// edit-prediction usage for the current period.
///
/// `llm_db` is optional; when absent, no subscription period or usage row is
/// looked up and default (zeroed) usage is reported instead.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Usage is scoped to the active subscription period; without the LLM
    // database (or without a current period) there is nothing to query.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // The minimum-account-age gate is skipped for Pro users and for accounts
    // carrying the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always have usage-based billing on; others follow their
        // stored preference (None when no preferences row exists).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Fall back to a zeroed usage message when no usage row exists yet.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2895
2896fn model_requests_limit(
2897 plan: cloud_llm_client::Plan,
2898 feature_flags: &Vec<String>,
2899) -> cloud_llm_client::UsageLimit {
2900 match plan.model_requests_limit() {
2901 cloud_llm_client::UsageLimit::Limited(limit) => {
2902 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2903 && feature_flags
2904 .iter()
2905 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2906 {
2907 1_000
2908 } else {
2909 limit
2910 };
2911
2912 cloud_llm_client::UsageLimit::Limited(limit)
2913 }
2914 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2915 }
2916}
2917
2918fn subscription_usage_to_proto(
2919 plan: proto::Plan,
2920 usage: crate::llm::db::subscription_usage::Model,
2921 feature_flags: &Vec<String>,
2922) -> proto::SubscriptionUsage {
2923 let plan = match plan {
2924 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2925 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2926 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2927 };
2928
2929 proto::SubscriptionUsage {
2930 model_requests_usage_amount: usage.model_requests as u32,
2931 model_requests_usage_limit: Some(proto::UsageLimit {
2932 variant: Some(match model_requests_limit(plan, feature_flags) {
2933 cloud_llm_client::UsageLimit::Limited(limit) => {
2934 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2935 limit: limit as u32,
2936 })
2937 }
2938 cloud_llm_client::UsageLimit::Unlimited => {
2939 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2940 }
2941 }),
2942 }),
2943 edit_predictions_usage_amount: usage.edit_predictions as u32,
2944 edit_predictions_usage_limit: Some(proto::UsageLimit {
2945 variant: Some(match plan.edit_predictions_limit() {
2946 cloud_llm_client::UsageLimit::Limited(limit) => {
2947 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2948 limit: limit as u32,
2949 })
2950 }
2951 cloud_llm_client::UsageLimit::Unlimited => {
2952 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2953 }
2954 }),
2955 }),
2956 }
2957}
2958
2959fn make_default_subscription_usage(
2960 plan: proto::Plan,
2961 feature_flags: &Vec<String>,
2962) -> proto::SubscriptionUsage {
2963 let plan = match plan {
2964 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2965 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2966 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2967 };
2968
2969 proto::SubscriptionUsage {
2970 model_requests_usage_amount: 0,
2971 model_requests_usage_limit: Some(proto::UsageLimit {
2972 variant: Some(match model_requests_limit(plan, feature_flags) {
2973 cloud_llm_client::UsageLimit::Limited(limit) => {
2974 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2975 limit: limit as u32,
2976 })
2977 }
2978 cloud_llm_client::UsageLimit::Unlimited => {
2979 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2980 }
2981 }),
2982 }),
2983 edit_predictions_usage_amount: 0,
2984 edit_predictions_usage_limit: Some(proto::UsageLimit {
2985 variant: Some(match plan.edit_predictions_limit() {
2986 cloud_llm_client::UsageLimit::Limited(limit) => {
2987 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2988 limit: limit as u32,
2989 })
2990 }
2991 cloud_llm_client::UsageLimit::Unlimited => {
2992 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2993 }
2994 }),
2995 }),
2996 }
2997}
2998
2999async fn update_user_plan(session: &Session) -> Result<()> {
3000 let db = session.db().await;
3001
3002 let update_user_plan = make_update_user_plan_message(
3003 session.principal.user(),
3004 session.is_staff(),
3005 &db.0,
3006 session.app_state.llm_db.clone(),
3007 )
3008 .await?;
3009
3010 session
3011 .peer
3012 .send(session.connection_id, update_user_plan)
3013 .trace_err();
3014
3015 Ok(())
3016}
3017
3018async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3019 subscribe_user_to_channels(session.user_id(), &session).await?;
3020 Ok(())
3021}
3022
3023async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3024 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3025 let mut pool = session.connection_pool().await;
3026 for membership in &channels_for_user.channel_memberships {
3027 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3028 }
3029 session.peer.send(
3030 session.connection_id,
3031 build_update_user_channels(&channels_for_user),
3032 )?;
3033 session.peer.send(
3034 session.connection_id,
3035 build_channels_update(channels_for_user),
3036 )?;
3037 Ok(())
3038}
3039
3040/// Creates a new channel.
3041async fn create_channel(
3042 request: proto::CreateChannel,
3043 response: Response<proto::CreateChannel>,
3044 session: Session,
3045) -> Result<()> {
3046 let db = session.db().await;
3047
3048 let parent_id = request.parent_id.map(ChannelId::from_proto);
3049 let (channel, membership) = db
3050 .create_channel(&request.name, parent_id, session.user_id())
3051 .await?;
3052
3053 let root_id = channel.root_id();
3054 let channel = Channel::from_model(channel);
3055
3056 response.send(proto::CreateChannelResponse {
3057 channel: Some(channel.to_proto()),
3058 parent_id: request.parent_id,
3059 })?;
3060
3061 let mut connection_pool = session.connection_pool().await;
3062 if let Some(membership) = membership {
3063 connection_pool.subscribe_to_channel(
3064 membership.user_id,
3065 membership.channel_id,
3066 membership.role,
3067 );
3068 let update = proto::UpdateUserChannels {
3069 channel_memberships: vec![proto::ChannelMembership {
3070 channel_id: membership.channel_id.to_proto(),
3071 role: membership.role.into(),
3072 }],
3073 ..Default::default()
3074 };
3075 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3076 session.peer.send(connection_id, update.clone())?;
3077 }
3078 }
3079
3080 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3081 if !role.can_see_channel(channel.visibility) {
3082 continue;
3083 }
3084
3085 let update = proto::UpdateChannels {
3086 channels: vec![channel.to_proto()],
3087 ..Default::default()
3088 };
3089 session.peer.send(connection_id, update.clone())?;
3090 }
3091
3092 Ok(())
3093}
3094
3095/// Delete a channel
3096async fn delete_channel(
3097 request: proto::DeleteChannel,
3098 response: Response<proto::DeleteChannel>,
3099 session: Session,
3100) -> Result<()> {
3101 let db = session.db().await;
3102
3103 let channel_id = request.channel_id;
3104 let (root_channel, removed_channels) = db
3105 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3106 .await?;
3107 response.send(proto::Ack {})?;
3108
3109 // Notify members of removed channels
3110 let mut update = proto::UpdateChannels::default();
3111 update
3112 .delete_channels
3113 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3114
3115 let connection_pool = session.connection_pool().await;
3116 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3117 session.peer.send(connection_id, update.clone())?;
3118 }
3119
3120 Ok(())
3121}
3122
3123/// Invite someone to join a channel.
3124async fn invite_channel_member(
3125 request: proto::InviteChannelMember,
3126 response: Response<proto::InviteChannelMember>,
3127 session: Session,
3128) -> Result<()> {
3129 let db = session.db().await;
3130 let channel_id = ChannelId::from_proto(request.channel_id);
3131 let invitee_id = UserId::from_proto(request.user_id);
3132 let InviteMemberResult {
3133 channel,
3134 notifications,
3135 } = db
3136 .invite_channel_member(
3137 channel_id,
3138 invitee_id,
3139 session.user_id(),
3140 request.role().into(),
3141 )
3142 .await?;
3143
3144 let update = proto::UpdateChannels {
3145 channel_invitations: vec![channel.to_proto()],
3146 ..Default::default()
3147 };
3148
3149 let connection_pool = session.connection_pool().await;
3150 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3151 session.peer.send(connection_id, update.clone())?;
3152 }
3153
3154 send_notifications(&connection_pool, &session.peer, notifications);
3155
3156 response.send(proto::Ack {})?;
3157 Ok(())
3158}
3159
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    // The database validates that the requester may remove `member_id` and
    // returns the membership changes plus the id of the notification (if any)
    // that should now be deleted on the removed member's clients.
    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Fix up the removed member's channel subscriptions and push the new
    // channel state to all of their connections.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            // Best effort: failing to clear the stale notification on one
            // connection should not fail the whole request.
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3201
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the loop body mutates the pool's channel
    // subscriptions, so we cannot keep a borrowing iterator alive over it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can see the channel under its new visibility are
        // (re)subscribed and sent the updated channel; everyone else is
        // unsubscribed and told to drop it from their channel list.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3248
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The target is already a member: bring their subscriptions and
        // connected clients in line with the new role.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // The target only has a pending invite: just refresh the invitation
        // shown on their connections.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3296
3297/// Change the name of a channel
3298async fn rename_channel(
3299 request: proto::RenameChannel,
3300 response: Response<proto::RenameChannel>,
3301 session: Session,
3302) -> Result<()> {
3303 let db = session.db().await;
3304 let channel_id = ChannelId::from_proto(request.channel_id);
3305 let channel_model = db
3306 .rename_channel(channel_id, session.user_id(), &request.name)
3307 .await?;
3308 let root_id = channel_model.root_id();
3309 let channel = Channel::from_model(channel_model);
3310
3311 response.send(proto::RenameChannelResponse {
3312 channel: Some(channel.to_proto()),
3313 })?;
3314
3315 let connection_pool = session.connection_pool().await;
3316 let update = proto::UpdateChannels {
3317 channels: vec![channel.to_proto()],
3318 ..Default::default()
3319 };
3320 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3321 if role.can_see_channel(channel.visibility) {
3322 session.peer.send(connection_id, update.clone())?;
3323 }
3324 }
3325
3326 Ok(())
3327}
3328
3329/// Move a channel to a new parent.
3330async fn move_channel(
3331 request: proto::MoveChannel,
3332 response: Response<proto::MoveChannel>,
3333 session: Session,
3334) -> Result<()> {
3335 let channel_id = ChannelId::from_proto(request.channel_id);
3336 let to = ChannelId::from_proto(request.to);
3337
3338 let (root_id, channels) = session
3339 .db()
3340 .await
3341 .move_channel(channel_id, to, session.user_id())
3342 .await?;
3343
3344 let connection_pool = session.connection_pool().await;
3345 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3346 let channels = channels
3347 .iter()
3348 .filter_map(|channel| {
3349 if role.can_see_channel(channel.visibility) {
3350 Some(channel.to_proto())
3351 } else {
3352 None
3353 }
3354 })
3355 .collect::<Vec<_>>();
3356 if channels.is_empty() {
3357 continue;
3358 }
3359
3360 let update = proto::UpdateChannels {
3361 channels,
3362 ..Default::default()
3363 };
3364
3365 session.peer.send(connection_id, update.clone())?;
3366 }
3367
3368 response.send(Ack {})?;
3369 Ok(())
3370}
3371
3372async fn reorder_channel(
3373 request: proto::ReorderChannel,
3374 response: Response<proto::ReorderChannel>,
3375 session: Session,
3376) -> Result<()> {
3377 let channel_id = ChannelId::from_proto(request.channel_id);
3378 let direction = request.direction();
3379
3380 let updated_channels = session
3381 .db()
3382 .await
3383 .reorder_channel(channel_id, direction, session.user_id())
3384 .await?;
3385
3386 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3387 let connection_pool = session.connection_pool().await;
3388 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3389 let channels = updated_channels
3390 .iter()
3391 .filter_map(|channel| {
3392 if role.can_see_channel(channel.visibility) {
3393 Some(channel.to_proto())
3394 } else {
3395 None
3396 }
3397 })
3398 .collect::<Vec<_>>();
3399
3400 if channels.is_empty() {
3401 continue;
3402 }
3403
3404 let update = proto::UpdateChannels {
3405 channels,
3406 ..Default::default()
3407 };
3408
3409 session.peer.send(connection_id, update.clone())?;
3410 }
3411 }
3412
3413 response.send(Ack {})?;
3414 Ok(())
3415}
3416
3417/// Get the list of channel members
3418async fn get_channel_members(
3419 request: proto::GetChannelMembers,
3420 response: Response<proto::GetChannelMembers>,
3421 session: Session,
3422) -> Result<()> {
3423 let db = session.db().await;
3424 let channel_id = ChannelId::from_proto(request.channel_id);
3425 let limit = if request.limit == 0 {
3426 u16::MAX as u64
3427 } else {
3428 request.limit
3429 };
3430 let (members, users) = db
3431 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3432 .await?;
3433 response.send(proto::GetChannelMembersResponse { members, users })?;
3434 Ok(())
3435}
3436
/// Accept or decline a channel invitation.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    // An accepted invite yields a membership update (new subscriptions plus
    // full channel state); a declined one only needs the invitation removed
    // from the responding user's own connections.
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    // Deliver any notifications produced by the database update.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3477
/// Join the channels' room
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Shared with the JoinRoom path via the `JoinChannelInternalResponse` trait.
    join_channel_internal(channel_id, Box::new(response), session).await
}
3487
/// Abstraction over the two response types (`JoinChannel` and `JoinRoom`)
/// that can carry a `JoinRoomResponse`, so `join_channel_internal` can serve both.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this calls `Response`'s own `send` rather than
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so this calls `Response`'s own `send` rather than
        // recursing into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3501
/// Shared implementation behind both `JoinChannel` and `JoinRoom` requests:
/// joins the channel's room, replies with room + optional LiveKit call info,
/// then fans out room/channel/contact updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Drop the guard before delegating; `leave_room_for_session`
            // presumably needs database access itself — hence the
            // drop/reacquire dance around the call.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, all other roles may publish. Token failures
        // are logged (`trace_err`) and treated as "no call media" rather than
        // failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Joining may have changed the user's channel membership; if so,
        // update their subscriptions and connected clients.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's other participants about the new occupant.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3595
3596/// Start editing the channel notes
3597async fn join_channel_buffer(
3598 request: proto::JoinChannelBuffer,
3599 response: Response<proto::JoinChannelBuffer>,
3600 session: Session,
3601) -> Result<()> {
3602 let db = session.db().await;
3603 let channel_id = ChannelId::from_proto(request.channel_id);
3604
3605 let open_response = db
3606 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3607 .await?;
3608
3609 let collaborators = open_response.collaborators.clone();
3610 response.send(open_response)?;
3611
3612 let update = UpdateChannelBufferCollaborators {
3613 channel_id: channel_id.to_proto(),
3614 collaborators: collaborators.clone(),
3615 };
3616 channel_buffer_updated(
3617 session.connection_id,
3618 collaborators
3619 .iter()
3620 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3621 &update,
3622 &session.peer,
3623 );
3624
3625 Ok(())
3626}
3627
/// Edit the channel notes
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Persist the buffer operations; the database returns the current
    // collaborator connections plus the buffer's new epoch/version.
    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Relay the raw operations to the other active collaborators.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Connections subscribed to the channel but not editing the buffer only
    // receive the latest epoch/version, enough to know the notes changed.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3678
/// Rejoin the channel notes after a connection blip
async fn rejoin_channel_buffers(
    request: proto::RejoinChannelBuffers,
    response: Response<proto::RejoinChannelBuffers>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let buffers = db
        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
        .await?;

    // For each rejoined buffer, send the refreshed collaborator list to the
    // other collaborators (those with a known peer id).
    for rejoined_buffer in &buffers {
        let collaborators_to_notify = rejoined_buffer
            .buffer
            .collaborators
            .iter()
            .filter_map(|c| Some(c.peer_id?.into()));
        channel_buffer_updated(
            session.connection_id,
            collaborators_to_notify,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: rejoined_buffer.buffer.channel_id,
                collaborators: rejoined_buffer.buffer.collaborators.clone(),
            },
            &session.peer,
        );
    }

    response.send(proto::RejoinChannelBuffersResponse {
        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
    })?;

    Ok(())
}
3713
3714/// Stop editing the channel notes
3715async fn leave_channel_buffer(
3716 request: proto::LeaveChannelBuffer,
3717 response: Response<proto::LeaveChannelBuffer>,
3718 session: Session,
3719) -> Result<()> {
3720 let db = session.db().await;
3721 let channel_id = ChannelId::from_proto(request.channel_id);
3722
3723 let left_buffer = db
3724 .leave_channel_buffer(channel_id, session.connection_id)
3725 .await?;
3726
3727 response.send(Ack {})?;
3728
3729 channel_buffer_updated(
3730 session.connection_id,
3731 left_buffer.connections,
3732 &proto::UpdateChannelBufferCollaborators {
3733 channel_id: channel_id.to_proto(),
3734 collaborators: left_buffer.collaborators,
3735 },
3736 &session.peer,
3737 );
3738
3739 Ok(())
3740}
3741
3742fn channel_buffer_updated<T: EnvelopedMessage>(
3743 sender_id: ConnectionId,
3744 collaborators: impl IntoIterator<Item = ConnectionId>,
3745 message: &T,
3746 peer: &Peer,
3747) {
3748 broadcast(Some(sender_id), collaborators, |peer_id| {
3749 peer.send(peer_id, message.clone())
3750 });
3751}
3752
3753fn send_notifications(
3754 connection_pool: &ConnectionPool,
3755 peer: &Peer,
3756 notifications: db::NotificationBatch,
3757) {
3758 for (user_id, notification) in notifications {
3759 for connection_id in connection_pool.user_connection_ids(user_id) {
3760 if let Err(error) = peer.send(
3761 connection_id,
3762 proto::AddNotification {
3763 notification: Some(notification.clone()),
3764 },
3765 ) {
3766 tracing::error!(
3767 "failed to send notification to {:?} {}",
3768 connection_id,
3769 error
3770 );
3771 }
3772 }
3773 }
3774}
3775
/// Send a message to the channel
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // Client-supplied nonce; presumably used for de-duplication on retry.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    // Echo the persisted message (with its assigned id) to the active chat
    // participants, excluding the sender, then reply to the sender.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only receive
    // the latest message id — enough to know there is something new.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3868
/// Delete a channel message
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    // The database checks permission and also returns the ids of any
    // notifications (e.g. mentions) tied to the deleted message.
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    // Forward the original request to the other chat participants and tell
    // them to drop the related notifications.
    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
3904
/// Edit the body and mentions of an existing channel message.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: Session,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // Note: `message_id` is re-bound below with the id returned from the
    // database.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Rebuild the full message: original creation timestamp, new body and
    // mentions, plus the edit time in `edited_at`.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // Mentions removed by the edit: delete their notifications.
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // Mention notifications whose content changed: update in place.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    // Deliver any newly-created notifications to their recipients.
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3988
3989/// Mark a channel message as read
3990async fn acknowledge_channel_message(
3991 request: proto::AckChannelMessage,
3992 session: Session,
3993) -> Result<()> {
3994 let channel_id = ChannelId::from_proto(request.channel_id);
3995 let message_id = MessageId::from_proto(request.message_id);
3996 let notifications = session
3997 .db()
3998 .await
3999 .observe_channel_message(channel_id, session.user_id(), message_id)
4000 .await?;
4001 send_notifications(
4002 &*session.connection_pool().await,
4003 &session.peer,
4004 notifications,
4005 );
4006 Ok(())
4007}
4008
4009/// Mark a buffer version as synced
4010async fn acknowledge_buffer_version(
4011 request: proto::AckBufferOperation,
4012 session: Session,
4013) -> Result<()> {
4014 let buffer_id = BufferId::from_proto(request.buffer_id);
4015 session
4016 .db()
4017 .await
4018 .observe_buffer_version(
4019 buffer_id,
4020 session.user_id(),
4021 request.epoch as i32,
4022 &request.version,
4023 )
4024 .await?;
4025 Ok(())
4026}
4027
4028/// Get a Supermaven API key for the user
4029async fn get_supermaven_api_key(
4030 _request: proto::GetSupermavenApiKey,
4031 response: Response<proto::GetSupermavenApiKey>,
4032 session: Session,
4033) -> Result<()> {
4034 let user_id: String = session.user_id().to_string();
4035 if !session.is_staff() {
4036 return Err(anyhow!("supermaven not enabled for this account"))?;
4037 }
4038
4039 let email = session.email().context("user must have an email")?;
4040
4041 let supermaven_admin_api = session
4042 .supermaven_client
4043 .as_ref()
4044 .context("supermaven not configured")?;
4045
4046 let result = supermaven_admin_api
4047 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4048 .await?;
4049
4050 response.send(proto::GetSupermavenApiKeyResponse {
4051 api_key: result.api_key,
4052 })?;
4053
4054 Ok(())
4055}
4056
4057/// Start receiving chat updates for a channel
4058async fn join_channel_chat(
4059 request: proto::JoinChannelChat,
4060 response: Response<proto::JoinChannelChat>,
4061 session: Session,
4062) -> Result<()> {
4063 let channel_id = ChannelId::from_proto(request.channel_id);
4064
4065 let db = session.db().await;
4066 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4067 .await?;
4068 let messages = db
4069 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4070 .await?;
4071 response.send(proto::JoinChannelChatResponse {
4072 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4073 messages,
4074 })?;
4075 Ok(())
4076}
4077
4078/// Stop receiving chat updates for a channel
4079async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4080 let channel_id = ChannelId::from_proto(request.channel_id);
4081 session
4082 .db()
4083 .await
4084 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4085 .await?;
4086 Ok(())
4087}
4088
4089/// Retrieve the chat history for a channel
4090async fn get_channel_messages(
4091 request: proto::GetChannelMessages,
4092 response: Response<proto::GetChannelMessages>,
4093 session: Session,
4094) -> Result<()> {
4095 let channel_id = ChannelId::from_proto(request.channel_id);
4096 let messages = session
4097 .db()
4098 .await
4099 .get_channel_messages(
4100 channel_id,
4101 session.user_id(),
4102 MESSAGE_COUNT_PER_PAGE,
4103 Some(MessageId::from_proto(request.before_message_id)),
4104 )
4105 .await?;
4106 response.send(proto::GetChannelMessagesResponse {
4107 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4108 messages,
4109 })?;
4110 Ok(())
4111}
4112
4113/// Retrieve specific chat messages
4114async fn get_channel_messages_by_id(
4115 request: proto::GetChannelMessagesById,
4116 response: Response<proto::GetChannelMessagesById>,
4117 session: Session,
4118) -> Result<()> {
4119 let message_ids = request
4120 .message_ids
4121 .iter()
4122 .map(|id| MessageId::from_proto(*id))
4123 .collect::<Vec<_>>();
4124 let messages = session
4125 .db()
4126 .await
4127 .get_channel_messages_by_id(session.user_id(), &message_ids)
4128 .await?;
4129 response.send(proto::GetChannelMessagesResponse {
4130 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4131 messages,
4132 })?;
4133 Ok(())
4134}
4135
4136/// Retrieve the current users notifications
4137async fn get_notifications(
4138 request: proto::GetNotifications,
4139 response: Response<proto::GetNotifications>,
4140 session: Session,
4141) -> Result<()> {
4142 let notifications = session
4143 .db()
4144 .await
4145 .get_notifications(
4146 session.user_id(),
4147 NOTIFICATION_COUNT_PER_PAGE,
4148 request.before_id.map(db::NotificationId::from_proto),
4149 )
4150 .await?;
4151 response.send(proto::GetNotificationsResponse {
4152 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4153 notifications,
4154 })?;
4155 Ok(())
4156}
4157
4158/// Mark notifications as read
4159async fn mark_notification_as_read(
4160 request: proto::MarkNotificationRead,
4161 response: Response<proto::MarkNotificationRead>,
4162 session: Session,
4163) -> Result<()> {
4164 let database = &session.db().await;
4165 let notifications = database
4166 .mark_notification_as_read_by_id(
4167 session.user_id(),
4168 NotificationId::from_proto(request.notification_id),
4169 )
4170 .await?;
4171 send_notifications(
4172 &*session.connection_pool().await,
4173 &session.peer,
4174 notifications,
4175 );
4176 response.send(proto::Ack {})?;
4177 Ok(())
4178}
4179
4180/// Get the current users information
4181async fn get_private_user_info(
4182 _request: proto::GetPrivateUserInfo,
4183 response: Response<proto::GetPrivateUserInfo>,
4184 session: Session,
4185) -> Result<()> {
4186 let db = session.db().await;
4187
4188 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4189 let user = db
4190 .get_user_by_id(session.user_id())
4191 .await?
4192 .context("user not found")?;
4193 let flags = db.get_user_flags(session.user_id()).await?;
4194
4195 response.send(proto::GetPrivateUserInfoResponse {
4196 metrics_id,
4197 staff: user.admin,
4198 flags,
4199 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4200 })?;
4201 Ok(())
4202}
4203
4204/// Accept the terms of service (tos) on behalf of the current user
4205async fn accept_terms_of_service(
4206 _request: proto::AcceptTermsOfService,
4207 response: Response<proto::AcceptTermsOfService>,
4208 session: Session,
4209) -> Result<()> {
4210 let db = session.db().await;
4211
4212 let accepted_tos_at = Utc::now();
4213 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4214 .await?;
4215
4216 response.send(proto::AcceptTermsOfServiceResponse {
4217 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4218 })?;
4219
4220 // When the user accepts the terms of service, we want to refresh their LLM
4221 // token to grant access.
4222 session
4223 .peer
4224 .send(session.connection_id, proto::RefreshLlmToken {})?;
4225
4226 Ok(())
4227}
4228
/// Issue an LLM access token for the requesting user. Along the way this
/// lazily provisions the user's billing state: a Stripe customer and a
/// Zed Free subscription are created if they don't exist yet.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Token issuance requires accepted terms of service.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Look up the billing customer; when missing, find or create one in
    // Stripe (keyed by email) and mirror it into our database.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Ensure there is an active subscription; default to the Zed Free plan.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Encode the user's entitlements into the LLM token claims.
    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4312
4313fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4314 let message = match message {
4315 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4316 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4317 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4318 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4319 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4320 code: frame.code.into(),
4321 reason: frame.reason.as_str().to_owned().into(),
4322 })),
4323 // We should never receive a frame while reading the message, according
4324 // to the `tungstenite` maintainers:
4325 //
4326 // > It cannot occur when you read messages from the WebSocket, but it
4327 // > can be used when you want to send the raw frames (e.g. you want to
4328 // > send the frames to the WebSocket without composing the full message first).
4329 // >
4330 // > — https://github.com/snapview/tungstenite-rs/issues/268
4331 TungsteniteMessage::Frame(_) => {
4332 bail!("received an unexpected frame while reading the message")
4333 }
4334 };
4335
4336 Ok(message)
4337}
4338
4339fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4340 match message {
4341 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4342 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4343 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4344 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4345 AxumMessage::Close(frame) => {
4346 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4347 code: frame.code.into(),
4348 reason: frame.reason.as_ref().into(),
4349 }))
4350 }
4351 }
4352}
4353
4354fn notify_membership_updated(
4355 connection_pool: &mut ConnectionPool,
4356 result: MembershipUpdated,
4357 user_id: UserId,
4358 peer: &Peer,
4359) {
4360 for membership in &result.new_channels.channel_memberships {
4361 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4362 }
4363 for channel_id in &result.removed_channels {
4364 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4365 }
4366
4367 let user_channels_update = proto::UpdateUserChannels {
4368 channel_memberships: result
4369 .new_channels
4370 .channel_memberships
4371 .iter()
4372 .map(|cm| proto::ChannelMembership {
4373 channel_id: cm.channel_id.to_proto(),
4374 role: cm.role.into(),
4375 })
4376 .collect(),
4377 ..Default::default()
4378 };
4379
4380 let mut update = build_channels_update(result.new_channels);
4381 update.delete_channels = result
4382 .removed_channels
4383 .into_iter()
4384 .map(|id| id.to_proto())
4385 .collect();
4386 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4387
4388 for connection_id in connection_pool.user_connection_ids(user_id) {
4389 peer.send(connection_id, user_channels_update.clone())
4390 .trace_err();
4391 peer.send(connection_id, update.clone()).trace_err();
4392 }
4393}
4394
4395fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4396 proto::UpdateUserChannels {
4397 channel_memberships: channels
4398 .channel_memberships
4399 .iter()
4400 .map(|m| proto::ChannelMembership {
4401 channel_id: m.channel_id.to_proto(),
4402 role: m.role.into(),
4403 })
4404 .collect(),
4405 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4406 observed_channel_message_id: channels.observed_channel_messages.clone(),
4407 }
4408}
4409
4410fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4411 let mut update = proto::UpdateChannels::default();
4412
4413 for channel in channels.channels {
4414 update.channels.push(channel.to_proto());
4415 }
4416
4417 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4418 update.latest_channel_message_ids = channels.latest_channel_messages;
4419
4420 for (channel_id, participants) in channels.channel_participants {
4421 update
4422 .channel_participants
4423 .push(proto::ChannelParticipants {
4424 channel_id: channel_id.to_proto(),
4425 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4426 });
4427 }
4428
4429 for channel in channels.invited_channels {
4430 update.channel_invitations.push(channel.to_proto());
4431 }
4432
4433 update
4434}
4435
4436fn build_initial_contacts_update(
4437 contacts: Vec<db::Contact>,
4438 pool: &ConnectionPool,
4439) -> proto::UpdateContacts {
4440 let mut update = proto::UpdateContacts::default();
4441
4442 for contact in contacts {
4443 match contact {
4444 db::Contact::Accepted { user_id, busy } => {
4445 update.contacts.push(contact_for_user(user_id, busy, pool));
4446 }
4447 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4448 db::Contact::Incoming { user_id } => {
4449 update
4450 .incoming_requests
4451 .push(proto::IncomingContactRequest {
4452 requester_id: user_id.to_proto(),
4453 })
4454 }
4455 }
4456 }
4457
4458 update
4459}
4460
4461fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4462 proto::Contact {
4463 user_id: user_id.to_proto(),
4464 online: pool.is_user_online(user_id),
4465 busy,
4466 }
4467}
4468
4469fn room_updated(room: &proto::Room, peer: &Peer) {
4470 broadcast(
4471 None,
4472 room.participants
4473 .iter()
4474 .filter_map(|participant| Some(participant.peer_id?.into())),
4475 |peer_id| {
4476 peer.send(
4477 peer_id,
4478 proto::RoomUpdated {
4479 room: Some(room.clone()),
4480 },
4481 )
4482 },
4483 );
4484}
4485
4486fn channel_updated(
4487 channel: &db::channel::Model,
4488 room: &proto::Room,
4489 peer: &Peer,
4490 pool: &ConnectionPool,
4491) {
4492 let participants = room
4493 .participants
4494 .iter()
4495 .map(|p| p.user_id)
4496 .collect::<Vec<_>>();
4497
4498 broadcast(
4499 None,
4500 pool.channel_connection_ids(channel.root_id())
4501 .filter_map(|(channel_id, role)| {
4502 role.can_see_channel(channel.visibility)
4503 .then_some(channel_id)
4504 }),
4505 |peer_id| {
4506 peer.send(
4507 peer_id,
4508 proto::UpdateChannels {
4509 channel_participants: vec![proto::ChannelParticipants {
4510 channel_id: channel.id.to_proto(),
4511 participant_user_ids: participants.clone(),
4512 }],
4513 ..Default::default()
4514 },
4515 )
4516 },
4517 );
4518}
4519
4520async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4521 let db = session.db().await;
4522
4523 let contacts = db.get_contacts(user_id).await?;
4524 let busy = db.is_user_busy(user_id).await?;
4525
4526 let pool = session.connection_pool().await;
4527 let updated_contact = contact_for_user(user_id, busy, &pool);
4528 for contact in contacts {
4529 if let db::Contact::Accepted {
4530 user_id: contact_user_id,
4531 ..
4532 } = contact
4533 {
4534 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4535 session
4536 .peer
4537 .send(
4538 contact_conn_id,
4539 proto::UpdateContacts {
4540 contacts: vec![updated_contact.clone()],
4541 remove_contacts: Default::default(),
4542 incoming_requests: Default::default(),
4543 remove_incoming_requests: Default::default(),
4544 outgoing_requests: Default::default(),
4545 remove_outgoing_requests: Default::default(),
4546 },
4547 )
4548 .trace_err();
4549 }
4550 }
4551 }
4552 Ok(())
4553}
4554
4555async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4556 let mut contacts_to_update = HashSet::default();
4557
4558 let room_id;
4559 let canceled_calls_to_user_ids;
4560 let livekit_room;
4561 let delete_livekit_room;
4562 let room;
4563 let channel;
4564
4565 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4566 contacts_to_update.insert(session.user_id());
4567
4568 for project in left_room.left_projects.values() {
4569 project_left(project, session);
4570 }
4571
4572 room_id = RoomId::from_proto(left_room.room.id);
4573 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4574 livekit_room = mem::take(&mut left_room.room.livekit_room);
4575 delete_livekit_room = left_room.deleted;
4576 room = mem::take(&mut left_room.room);
4577 channel = mem::take(&mut left_room.channel);
4578
4579 room_updated(&room, &session.peer);
4580 } else {
4581 return Ok(());
4582 }
4583
4584 if let Some(channel) = channel {
4585 channel_updated(
4586 &channel,
4587 &room,
4588 &session.peer,
4589 &*session.connection_pool().await,
4590 );
4591 }
4592
4593 {
4594 let pool = session.connection_pool().await;
4595 for canceled_user_id in canceled_calls_to_user_ids {
4596 for connection_id in pool.user_connection_ids(canceled_user_id) {
4597 session
4598 .peer
4599 .send(
4600 connection_id,
4601 proto::CallCanceled {
4602 room_id: room_id.to_proto(),
4603 },
4604 )
4605 .trace_err();
4606 }
4607 contacts_to_update.insert(canceled_user_id);
4608 }
4609 }
4610
4611 for contact_user_id in contacts_to_update {
4612 update_user_contacts(contact_user_id, session).await?;
4613 }
4614
4615 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4616 live_kit
4617 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4618 .await
4619 .trace_err();
4620
4621 if delete_livekit_room {
4622 live_kit.delete_room(livekit_room).await.trace_err();
4623 }
4624 }
4625
4626 Ok(())
4627}
4628
4629async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4630 let left_channel_buffers = session
4631 .db()
4632 .await
4633 .leave_channel_buffers(session.connection_id)
4634 .await?;
4635
4636 for left_buffer in left_channel_buffers {
4637 channel_buffer_updated(
4638 session.connection_id,
4639 left_buffer.connections,
4640 &proto::UpdateChannelBufferCollaborators {
4641 channel_id: left_buffer.channel_id.to_proto(),
4642 collaborators: left_buffer.collaborators,
4643 },
4644 &session.peer,
4645 );
4646 }
4647
4648 Ok(())
4649}
4650
4651fn project_left(project: &db::LeftProject, session: &Session) {
4652 for connection_id in &project.connection_ids {
4653 if project.should_unshare {
4654 session
4655 .peer
4656 .send(
4657 *connection_id,
4658 proto::UnshareProject {
4659 project_id: project.id.to_proto(),
4660 },
4661 )
4662 .trace_err();
4663 } else {
4664 session
4665 .peer
4666 .send(
4667 *connection_id,
4668 proto::RemoveProjectCollaborator {
4669 project_id: project.id.to_proto(),
4670 peer_id: Some(session.connection_id.into()),
4671 },
4672 )
4673 .trace_err();
4674 }
4675 }
4676}
4677
/// Extension trait for logging-and-discarding errors from fallible
/// operations (e.g. best-effort sends to peers).
pub trait ResultExt {
    /// The success type carried by the implementing result.
    type Ok;

    /// Logs the error (if any) and converts the result into an `Option`,
    /// discarding the error value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4683
4684impl<T, E> ResultExt for Result<T, E>
4685where
4686 E: std::fmt::Debug,
4687{
4688 type Ok = T;
4689
4690 #[track_caller]
4691 fn trace_err(self) -> Option<T> {
4692 match self {
4693 Ok(value) => Some(value),
4694 Err(error) => {
4695 tracing::error!("{:?}", error);
4696 None
4697 }
4698 }
4699 }
4700}