1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use reqwest_client::ReqwestClient;
45use rpc::proto::{MultiLspQuery, split_repository_update};
46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
86
// Kubernetes gives terminated pods 10 seconds to shut down gracefully. Once they're gone, we can clean up their stale resources.
88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
89
90const MESSAGE_COUNT_PER_PAGE: usize = 100;
91const MAX_MESSAGE_LEN: usize = 1024;
92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
94
95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
96
97type MessageHandler =
98 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
99
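/// RAII guard that counts an active client connection against
/// `MAX_CONCURRENT_CONNECTIONS`; dropping the guard releases the slot.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let Ok(guard) = ConnectionGuard::try_acquire() else {
///     return; // the server is at capacity
/// };
/// // ... handle the connection; the slot is freed when `guard` is dropped.
/// ```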
100pub struct ConnectionGuard;
101
102impl ConnectionGuard {
103 pub fn try_acquire() -> Result<Self, ()> {
104 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
105 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
106 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
107 tracing::error!(
108 "too many concurrent connections: {}",
109 current_connections + 1
110 );
111 return Err(());
112 }
113 Ok(ConnectionGuard)
114 }
115}
116
117impl Drop for ConnectionGuard {
118 fn drop(&mut self) {
119 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
120 }
121}
122
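/// Wraps the receipt for an incoming request and records whether the handler
/// has responded; `add_request_handler` treats a handler that returns `Ok(())`
/// without calling `send` as an error.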
123struct Response<R> {
124 peer: Arc<Peer>,
125 receipt: Receipt<R>,
126 responded: Arc<AtomicBool>,
127}
128
129impl<R: RequestMessage> Response<R> {
130 fn send(self, payload: R::Response) -> Result<()> {
131 self.responded.store(true, SeqCst);
132 self.peer.respond(self.receipt, payload)?;
133 Ok(())
134 }
135}
136
137#[derive(Clone, Debug)]
138pub enum Principal {
139 User(User),
140 Impersonated { user: User, admin: User },
141}
142
143impl Principal {
144 fn user(&self) -> &User {
145 match self {
146 Principal::User(user) => user,
147 Principal::Impersonated { user, .. } => user,
148 }
149 }
150
151 fn update_span(&self, span: &tracing::Span) {
152 match &self {
153 Principal::User(user) => {
154 span.record("user_id", user.id.0);
155 span.record("login", &user.github_login);
156 }
157 Principal::Impersonated { user, admin } => {
158 span.record("user_id", user.id.0);
159 span.record("login", &user.github_login);
160 span.record("impersonator", &admin.github_login);
161 }
162 }
163 }
164}
165
166#[derive(Clone)]
167struct Session {
168 principal: Principal,
169 connection_id: ConnectionId,
170 db: Arc<tokio::sync::Mutex<DbHandle>>,
171 peer: Arc<Peer>,
172 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
173 app_state: Arc<AppState>,
174 supermaven_client: Option<Arc<SupermavenAdminApi>>,
175 /// The GeoIP country code for the user.
176 #[allow(unused)]
177 geoip_country_code: Option<String>,
178 system_id: Option<String>,
179 _executor: Executor,
180}
181
182impl Session {
183 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
184 #[cfg(test)]
185 tokio::task::yield_now().await;
186 let guard = self.db.lock().await;
187 #[cfg(test)]
188 tokio::task::yield_now().await;
189 guard
190 }
191
192 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
193 #[cfg(test)]
194 tokio::task::yield_now().await;
195 let guard = self.connection_pool.lock();
196 ConnectionPoolGuard {
197 guard,
198 _not_send: PhantomData,
199 }
200 }
201
202 fn is_staff(&self) -> bool {
203 match &self.principal {
204 Principal::User(user) => user.admin,
205 Principal::Impersonated { .. } => true,
206 }
207 }
208
209 fn user_id(&self) -> UserId {
210 match &self.principal {
211 Principal::User(user) => user.id,
212 Principal::Impersonated { user, .. } => user.id,
213 }
214 }
215
216 pub fn email(&self) -> Option<String> {
217 match &self.principal {
218 Principal::User(user) => user.email_address.clone(),
219 Principal::Impersonated { user, .. } => user.email_address.clone(),
220 }
221 }
222}
223
224impl Debug for Session {
225 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
226 let mut result = f.debug_struct("Session");
227 match &self.principal {
228 Principal::User(user) => {
229 result.field("user", &user.github_login);
230 }
231 Principal::Impersonated { user, admin } => {
232 result.field("user", &user.github_login);
233 result.field("impersonator", &admin.github_login);
234 }
235 }
236 result.field("connection_id", &self.connection_id).finish()
237 }
238}
239
240struct DbHandle(Arc<Database>);
241
242impl Deref for DbHandle {
243 type Target = Database;
244
245 fn deref(&self) -> &Self::Target {
246 self.0.as_ref()
247 }
248}
249
250pub struct Server {
251 id: parking_lot::Mutex<ServerId>,
252 peer: Arc<Peer>,
253 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
254 app_state: Arc<AppState>,
255 handlers: HashMap<TypeId, MessageHandler>,
256 teardown: watch::Sender<bool>,
257}
258
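/// A lock guard over the connection pool. The `PhantomData<Rc<()>>` marker makes
/// the guard `!Send`, so it cannot be held across an `.await` point that may move
/// the task to another thread.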
259pub(crate) struct ConnectionPoolGuard<'a> {
260 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
261 _not_send: PhantomData<Rc<()>>,
262}
263
264#[derive(Serialize)]
265pub struct ServerSnapshot<'a> {
266 peer: &'a Peer,
267 #[serde(serialize_with = "serialize_deref")]
268 connection_pool: ConnectionPoolGuard<'a>,
269}
270
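/// Serializes a value through its `Deref` target; used to serialize the
/// connection pool via `ConnectionPoolGuard` in `ServerSnapshot`.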
271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
272where
273 S: Serializer,
274 T: Deref<Target = U>,
275 U: Serialize,
276{
277 Serialize::serialize(value.deref(), serializer)
278}
279
280impl Server {
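    /// Creates the server and registers a handler for every RPC message type it
    /// serves. Registering two handlers for the same message type panics (see
    /// `add_handler`).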
281 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
282 let mut server = Self {
283 id: parking_lot::Mutex::new(id),
284 peer: Peer::new(id.0 as u32),
285 app_state: app_state.clone(),
286 connection_pool: Default::default(),
287 handlers: Default::default(),
288 teardown: watch::channel(false).0,
289 };
290
291 server
292 .add_request_handler(ping)
293 .add_request_handler(create_room)
294 .add_request_handler(join_room)
295 .add_request_handler(rejoin_room)
296 .add_request_handler(leave_room)
297 .add_request_handler(set_room_participant_role)
298 .add_request_handler(call)
299 .add_request_handler(cancel_call)
300 .add_message_handler(decline_call)
301 .add_request_handler(update_participant_location)
302 .add_request_handler(share_project)
303 .add_message_handler(unshare_project)
304 .add_request_handler(join_project)
305 .add_message_handler(leave_project)
306 .add_request_handler(update_project)
307 .add_request_handler(update_worktree)
308 .add_request_handler(update_repository)
309 .add_request_handler(remove_repository)
310 .add_message_handler(start_language_server)
311 .add_message_handler(update_language_server)
312 .add_message_handler(update_diagnostic_summary)
313 .add_message_handler(update_worktree_settings)
314 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
317 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
318 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
324 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
325 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
326 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
327 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
328 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
329 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
330 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
334 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
335 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
336 .add_request_handler(
337 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
338 )
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
342 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
344 .add_request_handler(
345 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
346 )
347 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
348 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
349 .add_request_handler(
350 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
351 )
352 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
353 .add_request_handler(
354 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
355 )
356 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
357 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
358 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
359 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
360 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
361 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
362 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
365 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
368 .add_request_handler(
369 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
370 )
371 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
372 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
373 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
374 .add_request_handler(multi_lsp_query)
375 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
376 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
377 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
378 .add_message_handler(create_buffer_for_peer)
379 .add_request_handler(update_buffer)
380 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
381 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
382 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
383 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
386 .add_message_handler(
387 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
388 )
389 .add_request_handler(get_users)
390 .add_request_handler(fuzzy_search_users)
391 .add_request_handler(request_contact)
392 .add_request_handler(remove_contact)
393 .add_request_handler(respond_to_contact_request)
394 .add_message_handler(subscribe_to_channels)
395 .add_request_handler(create_channel)
396 .add_request_handler(delete_channel)
397 .add_request_handler(invite_channel_member)
398 .add_request_handler(remove_channel_member)
399 .add_request_handler(set_channel_member_role)
400 .add_request_handler(set_channel_visibility)
401 .add_request_handler(rename_channel)
402 .add_request_handler(join_channel_buffer)
403 .add_request_handler(leave_channel_buffer)
404 .add_message_handler(update_channel_buffer)
405 .add_request_handler(rejoin_channel_buffers)
406 .add_request_handler(get_channel_members)
407 .add_request_handler(respond_to_channel_invite)
408 .add_request_handler(join_channel)
409 .add_request_handler(join_channel_chat)
410 .add_message_handler(leave_channel_chat)
411 .add_request_handler(send_channel_message)
412 .add_request_handler(remove_channel_message)
413 .add_request_handler(update_channel_message)
414 .add_request_handler(get_channel_messages)
415 .add_request_handler(get_channel_messages_by_id)
416 .add_request_handler(get_notifications)
417 .add_request_handler(mark_notification_as_read)
418 .add_request_handler(move_channel)
419 .add_request_handler(reorder_channel)
420 .add_request_handler(follow)
421 .add_message_handler(unfollow)
422 .add_message_handler(update_followers)
423 .add_request_handler(get_private_user_info)
424 .add_request_handler(get_llm_api_token)
425 .add_request_handler(accept_terms_of_service)
426 .add_message_handler(acknowledge_channel_message)
427 .add_message_handler(acknowledge_buffer_version)
428 .add_request_handler(get_supermaven_api_key)
429 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
430 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
431 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
432 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
433 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
434 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
435 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
436 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
437 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
438 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
439 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
440 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
441 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
443 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
444 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
445 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
446 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
447 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
448 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
449 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
450 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
451 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
452 .add_message_handler(update_context);
453
454 Arc::new(server)
455 }
456
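    /// Spawns the post-deploy cleanup task: after `CLEANUP_TIMEOUT`, stale rooms,
    /// channel buffers, chat participants, worktree entries, and servers left over
    /// from previous server instances are cleaned up and affected clients are notified.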
457 pub async fn start(&self) -> Result<()> {
458 let server_id = *self.id.lock();
459 let app_state = self.app_state.clone();
460 let peer = self.peer.clone();
461 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
462 let pool = self.connection_pool.clone();
463 let livekit_client = self.app_state.livekit_client.clone();
464
465 let span = info_span!("start server");
466 self.app_state.executor.spawn_detached(
467 async move {
468 tracing::info!("waiting for cleanup timeout");
469 timeout.await;
470 tracing::info!("cleanup timeout expired, retrieving stale rooms");
471
472 app_state
473 .db
474 .delete_stale_channel_chat_participants(
475 &app_state.config.zed_environment,
476 server_id,
477 )
478 .await
479 .trace_err();
480
481 if let Some((room_ids, channel_ids)) = app_state
482 .db
483 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
484 .await
485 .trace_err()
486 {
487 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
488 tracing::info!(
489 stale_channel_buffer_count = channel_ids.len(),
490 "retrieved stale channel buffers"
491 );
492
493 for channel_id in channel_ids {
494 if let Some(refreshed_channel_buffer) = app_state
495 .db
496 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
497 .await
498 .trace_err()
499 {
500 for connection_id in refreshed_channel_buffer.connection_ids {
501 peer.send(
502 connection_id,
503 proto::UpdateChannelBufferCollaborators {
504 channel_id: channel_id.to_proto(),
505 collaborators: refreshed_channel_buffer
506 .collaborators
507 .clone(),
508 },
509 )
510 .trace_err();
511 }
512 }
513 }
514
515 for room_id in room_ids {
516 let mut contacts_to_update = HashSet::default();
517 let mut canceled_calls_to_user_ids = Vec::new();
518 let mut livekit_room = String::new();
519 let mut delete_livekit_room = false;
520
521 if let Some(mut refreshed_room) = app_state
522 .db
523 .clear_stale_room_participants(room_id, server_id)
524 .await
525 .trace_err()
526 {
527 tracing::info!(
528 room_id = room_id.0,
529 new_participant_count = refreshed_room.room.participants.len(),
530 "refreshed room"
531 );
532 room_updated(&refreshed_room.room, &peer);
533 if let Some(channel) = refreshed_room.channel.as_ref() {
534 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
535 }
536 contacts_to_update
537 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
538 contacts_to_update
539 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
540 canceled_calls_to_user_ids =
541 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
542 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
543 delete_livekit_room = refreshed_room.room.participants.is_empty();
544 }
545
546 {
547 let pool = pool.lock();
548 for canceled_user_id in canceled_calls_to_user_ids {
549 for connection_id in pool.user_connection_ids(canceled_user_id) {
550 peer.send(
551 connection_id,
552 proto::CallCanceled {
553 room_id: room_id.to_proto(),
554 },
555 )
556 .trace_err();
557 }
558 }
559 }
560
561 for user_id in contacts_to_update {
562 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
563 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
564 if let Some((busy, contacts)) = busy.zip(contacts) {
565 let pool = pool.lock();
566 let updated_contact = contact_for_user(user_id, busy, &pool);
567 for contact in contacts {
568 if let db::Contact::Accepted {
569 user_id: contact_user_id,
570 ..
571 } = contact
572 {
573 for contact_conn_id in
574 pool.user_connection_ids(contact_user_id)
575 {
576 peer.send(
577 contact_conn_id,
578 proto::UpdateContacts {
579 contacts: vec![updated_contact.clone()],
580 remove_contacts: Default::default(),
581 incoming_requests: Default::default(),
582 remove_incoming_requests: Default::default(),
583 outgoing_requests: Default::default(),
584 remove_outgoing_requests: Default::default(),
585 },
586 )
587 .trace_err();
588 }
589 }
590 }
591 }
592 }
593
594 if let Some(live_kit) = livekit_client.as_ref() {
595 if delete_livekit_room {
596 live_kit.delete_room(livekit_room).await.trace_err();
597 }
598 }
599 }
600 }
601
602 app_state
603 .db
604 .delete_stale_channel_chat_participants(
605 &app_state.config.zed_environment,
606 server_id,
607 )
608 .await
609 .trace_err();
610
611 app_state
612 .db
613 .clear_old_worktree_entries(server_id)
614 .await
615 .trace_err();
616
617 app_state
618 .db
619 .delete_stale_servers(&app_state.config.zed_environment, server_id)
620 .await
621 .trace_err();
622 }
623 .instrument(span),
624 );
625 Ok(())
626 }
627
628 pub fn teardown(&self) {
629 self.peer.teardown();
630 self.connection_pool.lock().reset();
631 let _ = self.teardown.send(true);
632 }
633
634 #[cfg(test)]
635 pub fn reset(&self, id: ServerId) {
636 self.teardown();
637 *self.id.lock() = id;
638 self.peer.reset(id.0 as u32);
639 let _ = self.teardown.send(false);
640 }
641
642 #[cfg(test)]
643 pub fn id(&self) -> ServerId {
644 *self.id.lock()
645 }
646
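    /// Registers a handler for a message type, wrapping it with logging of queue
    /// and processing durations. Panics if a handler is already registered for
    /// the same message type.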
647 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
648 where
649 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
650 Fut: 'static + Send + Future<Output = Result<()>>,
651 M: EnvelopedMessage,
652 {
653 let prev_handler = self.handlers.insert(
654 TypeId::of::<M>(),
655 Box::new(move |envelope, session| {
656 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
657 let received_at = envelope.received_at;
658 tracing::info!("message received");
659 let start_time = Instant::now();
660 let future = (handler)(*envelope, session);
661 async move {
662 let result = future.await;
663 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
664 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
665 let queue_duration_ms = total_duration_ms - processing_duration_ms;
666
667 match result {
668 Err(error) => {
669 tracing::error!(
670 ?error,
671 total_duration_ms,
672 processing_duration_ms,
673 queue_duration_ms,
674 "error handling message"
675 )
676 }
677 Ok(()) => tracing::info!(
678 total_duration_ms,
679 processing_duration_ms,
680 queue_duration_ms,
681 "finished handling message"
682 ),
683 }
684 }
685 .boxed()
686 }),
687 );
688 if prev_handler.is_some() {
689 panic!("registered a handler for the same message twice");
690 }
691 self
692 }
693
694 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
695 where
696 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
697 Fut: 'static + Send + Future<Output = Result<()>>,
698 M: EnvelopedMessage,
699 {
700 self.add_handler(move |envelope, session| handler(envelope.payload, session));
701 self
702 }
703
704 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
705 where
706 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
707 Fut: Send + Future<Output = Result<()>>,
708 M: RequestMessage,
709 {
710 let handler = Arc::new(handler);
711 self.add_handler(move |envelope, session| {
712 let receipt = envelope.receipt();
713 let handler = handler.clone();
714 async move {
715 let peer = session.peer.clone();
716 let responded = Arc::new(AtomicBool::default());
717 let response = Response {
718 peer: peer.clone(),
719 responded: responded.clone(),
720 receipt,
721 };
722 match (handler)(envelope.payload, response, session).await {
723 Ok(()) => {
724 if responded.load(std::sync::atomic::Ordering::SeqCst) {
725 Ok(())
726 } else {
727 Err(anyhow!("handler did not send a response"))?
728 }
729 }
730 Err(error) => {
731 let proto_err = match &error {
732 Error::Internal(err) => err.to_proto(),
733 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
734 };
735 peer.respond_with_error(receipt, proto_err)?;
736 Err(error)
737 }
738 }
739 }
740 })
741 }
742
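    /// Drives a single client connection: registers it with the peer, sends the
    /// initial client update, then dispatches incoming messages to their handlers
    /// (background messages on the executor, foreground messages through a bounded
    /// `FuturesUnordered`), and finally runs `connection_lost` cleanup.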
743 pub fn handle_connection(
744 self: &Arc<Self>,
745 connection: Connection,
746 address: String,
747 principal: Principal,
748 zed_version: ZedVersion,
749 release_channel: Option<String>,
750 user_agent: Option<String>,
751 geoip_country_code: Option<String>,
752 system_id: Option<String>,
753 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
754 executor: Executor,
755 connection_guard: Option<ConnectionGuard>,
756 ) -> impl Future<Output = ()> + use<> {
757 let this = self.clone();
758 let span = info_span!("handle connection", %address,
759 connection_id=field::Empty,
760 user_id=field::Empty,
761 login=field::Empty,
762 impersonator=field::Empty,
763 user_agent=field::Empty,
764 geoip_country_code=field::Empty,
765 release_channel=field::Empty,
766 );
767 principal.update_span(&span);
768 if let Some(user_agent) = user_agent {
769 span.record("user_agent", user_agent);
770 }
771 if let Some(release_channel) = release_channel {
772 span.record("release_channel", release_channel);
773 }
774
775 if let Some(country_code) = geoip_country_code.as_ref() {
776 span.record("geoip_country_code", country_code);
777 }
778
779 let mut teardown = self.teardown.subscribe();
780 async move {
781 if *teardown.borrow() {
782 tracing::error!("server is tearing down");
783 return;
784 }
785
786 let (connection_id, handle_io, mut incoming_rx) =
787 this.peer.add_connection(connection, {
788 let executor = executor.clone();
789 move |duration| executor.sleep(duration)
790 });
791 tracing::Span::current().record("connection_id", format!("{}", connection_id));
792
793 tracing::info!("connection opened");
794
795 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
796 let http_client = match ReqwestClient::user_agent(&user_agent) {
797 Ok(http_client) => Arc::new(http_client),
798 Err(error) => {
799 tracing::error!(?error, "failed to create HTTP client");
800 return;
801 }
802 };
803
804 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
805 |supermaven_admin_api_key| {
806 Arc::new(SupermavenAdminApi::new(
807 supermaven_admin_api_key.to_string(),
808 http_client.clone(),
809 ))
810 },
811 );
812
813 let session = Session {
814 principal: principal.clone(),
815 connection_id,
816 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
817 peer: this.peer.clone(),
818 connection_pool: this.connection_pool.clone(),
819 app_state: this.app_state.clone(),
820 geoip_country_code,
821 system_id,
822 _executor: executor.clone(),
823 supermaven_client,
824 };
825
826 if let Err(error) = this
827 .send_initial_client_update(
828 connection_id,
829 zed_version,
830 send_connection_id,
831 &session,
832 )
833 .await
834 {
835 tracing::error!(?error, "failed to send initial client update");
836 return;
837 }
838 drop(connection_guard);
839
840 let handle_io = handle_io.fuse();
841 futures::pin_mut!(handle_io);
842
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
851 const MAX_CONCURRENT_HANDLERS: usize = 256;
852 let mut foreground_message_handlers = FuturesUnordered::new();
853 let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
854 let get_concurrent_handlers = {
855 let concurrent_handlers = concurrent_handlers.clone();
856 move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
857 };
858 loop {
859 let next_message = async {
860 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
861 let message = incoming_rx.next().await;
862 // Cache the concurrent_handlers here, so that we know what the
863 // queue looks like as each handler starts
864 (permit, message, get_concurrent_handlers())
865 }
866 .fuse();
867 futures::pin_mut!(next_message);
868 futures::select_biased! {
869 _ = teardown.changed().fuse() => return,
870 result = handle_io => {
871 if let Err(error) = result {
872 tracing::error!(?error, "error handling I/O");
873 }
874 break;
875 }
876 _ = foreground_message_handlers.next() => {}
877 next_message = next_message => {
878 let (permit, message, concurrent_handlers) = next_message;
879 if let Some(message) = message {
880 let type_name = message.payload_type_name();
881 // note: we copy all the fields from the parent span so we can query them in the logs.
882 // (https://github.com/tokio-rs/tracing/issues/2670).
883 let span = tracing::info_span!("receive message",
884 %connection_id,
885 %address,
886 type_name,
887 concurrent_handlers,
888 user_id=field::Empty,
889 login=field::Empty,
890 impersonator=field::Empty,
891 multi_lsp_query_request=field::Empty,
892 );
893 principal.update_span(&span);
894 let span_enter = span.enter();
895 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
896 let is_background = message.is_background();
897 let handle_message = (handler)(message, session.clone());
898 drop(span_enter);
899
900 let handle_message = async move {
901 handle_message.await;
902 drop(permit);
903 }.instrument(span);
904 if is_background {
905 executor.spawn_detached(handle_message);
906 } else {
907 foreground_message_handlers.push(handle_message);
908 }
909 } else {
910 tracing::error!("no message handler");
911 }
912 } else {
913 tracing::info!("connection closed");
914 break;
915 }
916 }
917 }
918 }
919
920 drop(foreground_message_handlers);
921 let concurrent_handlers = get_concurrent_handlers();
922 tracing::info!(concurrent_handlers, "signing out");
923 if let Err(error) = connection_lost(session, teardown, executor).await {
924 tracing::error!(?error, "error signing out");
925 }
926 }
927 .instrument(span)
928 }
929
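    /// Sends the `Hello` message, initial contacts, plan information, channel
    /// subscriptions (when applicable for the client version), and any pending
    /// incoming call to a newly connected client.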
930 async fn send_initial_client_update(
931 &self,
932 connection_id: ConnectionId,
933 zed_version: ZedVersion,
934 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
935 session: &Session,
936 ) -> Result<()> {
937 self.peer.send(
938 connection_id,
939 proto::Hello {
940 peer_id: Some(connection_id.into()),
941 },
942 )?;
943 tracing::info!("sent hello message");
944 if let Some(send_connection_id) = send_connection_id.take() {
945 let _ = send_connection_id.send(connection_id);
946 }
947
948 match &session.principal {
949 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
950 if !user.connected_once {
951 self.peer.send(connection_id, proto::ShowContacts {})?;
952 self.app_state
953 .db
954 .set_user_connected_once(user.id, true)
955 .await?;
956 }
957
958 update_user_plan(session).await?;
959
960 let contacts = self.app_state.db.get_contacts(user.id).await?;
961
962 {
963 let mut pool = self.connection_pool.lock();
964 pool.add_connection(connection_id, user.id, user.admin, zed_version);
965 self.peer.send(
966 connection_id,
967 build_initial_contacts_update(contacts, &pool),
968 )?;
969 }
970
971 if should_auto_subscribe_to_channels(zed_version) {
972 subscribe_user_to_channels(user.id, session).await?;
973 }
974
975 if let Some(incoming_call) =
976 self.app_state.db.incoming_call_for_user(user.id).await?
977 {
978 self.peer.send(connection_id, incoming_call)?;
979 }
980
981 update_user_contacts(user.id, session).await?;
982 }
983 }
984
985 Ok(())
986 }
987
988 pub async fn invite_code_redeemed(
989 self: &Arc<Self>,
990 inviter_id: UserId,
991 invitee_id: UserId,
992 ) -> Result<()> {
993 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
994 if let Some(code) = &user.invite_code {
995 let pool = self.connection_pool.lock();
996 let invitee_contact = contact_for_user(invitee_id, false, &pool);
997 for connection_id in pool.user_connection_ids(inviter_id) {
998 self.peer.send(
999 connection_id,
1000 proto::UpdateContacts {
1001 contacts: vec![invitee_contact.clone()],
1002 ..Default::default()
1003 },
1004 )?;
1005 self.peer.send(
1006 connection_id,
1007 proto::UpdateInviteInfo {
1008 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1009 count: user.invite_count as u32,
1010 },
1011 )?;
1012 }
1013 }
1014 }
1015 Ok(())
1016 }
1017
1018 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1019 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1020 if let Some(invite_code) = &user.invite_code {
1021 let pool = self.connection_pool.lock();
1022 for connection_id in pool.user_connection_ids(user_id) {
1023 self.peer.send(
1024 connection_id,
1025 proto::UpdateInviteInfo {
1026 url: format!(
1027 "{}{}",
1028 self.app_state.config.invite_link_prefix, invite_code
1029 ),
1030 count: user.invite_count as u32,
1031 },
1032 )?;
1033 }
1034 }
1035 }
1036 Ok(())
1037 }
1038
1039 pub async fn update_plan_for_user(
1040 self: &Arc<Self>,
1041 user_id: UserId,
1042 update_user_plan: proto::UpdateUserPlan,
1043 ) -> Result<()> {
1044 let pool = self.connection_pool.lock();
1045 for connection_id in pool.user_connection_ids(user_id) {
1046 self.peer
1047 .send(connection_id, update_user_plan.clone())
1048 .trace_err();
1049 }
1050
1051 Ok(())
1052 }
1053
1054 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1055 /// message on the Collab server.
1056 ///
1057 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1058 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1059 let user = self
1060 .app_state
1061 .db
1062 .get_user_by_id(user_id)
1063 .await?
1064 .context("user not found")?;
1065
1066 let update_user_plan = make_update_user_plan_message(
1067 &user,
1068 user.admin,
1069 &self.app_state.db,
1070 self.app_state.llm_db.clone(),
1071 )
1072 .await?;
1073
1074 self.update_plan_for_user(user_id, update_user_plan).await
1075 }
1076
1077 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1078 let pool = self.connection_pool.lock();
1079 for connection_id in pool.user_connection_ids(user_id) {
1080 self.peer
1081 .send(connection_id, proto::RefreshLlmToken {})
1082 .trace_err();
1083 }
1084 }
1085
1086 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1087 ServerSnapshot {
1088 connection_pool: ConnectionPoolGuard {
1089 guard: self.connection_pool.lock(),
1090 _not_send: PhantomData,
1091 },
1092 peer: &self.peer,
1093 }
1094 }
1095}
1096
1097impl Deref for ConnectionPoolGuard<'_> {
1098 type Target = ConnectionPool;
1099
1100 fn deref(&self) -> &Self::Target {
1101 &self.guard
1102 }
1103}
1104
1105impl DerefMut for ConnectionPoolGuard<'_> {
1106 fn deref_mut(&mut self) -> &mut Self::Target {
1107 &mut self.guard
1108 }
1109}
1110
1111impl Drop for ConnectionPoolGuard<'_> {
1112 fn drop(&mut self) {
1113 #[cfg(test)]
1114 self.check_invariants();
1115 }
1116}
1117
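/// Sends a message to every receiver except the sender, logging (rather than
/// propagating) per-connection send failures.
///
/// Illustrative usage (the names below are placeholders):
///
/// ```ignore
/// broadcast(Some(sender_connection_id), receiver_connection_ids, |connection_id| {
///     peer.forward_send(sender_connection_id, connection_id, message.clone())
/// });
/// ```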
1118fn broadcast<F>(
1119 sender_id: Option<ConnectionId>,
1120 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1121 mut f: F,
1122) where
1123 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1124{
1125 for receiver_id in receiver_ids {
1126 if Some(receiver_id) != sender_id {
1127 if let Err(error) = f(receiver_id) {
                tracing::error!("failed to send to {:?}: {}", receiver_id, error);
1129 }
1130 }
1131 }
1132}
1133
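/// The `x-zed-protocol-version` request header, carrying the client's RPC protocol version.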
1134pub struct ProtocolVersion(u32);
1135
1136impl Header for ProtocolVersion {
1137 fn name() -> &'static HeaderName {
1138 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1139 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1140 }
1141
1142 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1143 where
1144 Self: Sized,
1145 I: Iterator<Item = &'i axum::http::HeaderValue>,
1146 {
1147 let version = values
1148 .next()
1149 .ok_or_else(axum::headers::Error::invalid)?
1150 .to_str()
1151 .map_err(|_| axum::headers::Error::invalid())?
1152 .parse()
1153 .map_err(|_| axum::headers::Error::invalid())?;
1154 Ok(Self(version))
1155 }
1156
1157 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1158 values.extend([self.0.to_string().parse().unwrap()]);
1159 }
1160}
1161
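/// The `x-zed-app-version` request header, carrying the client's semantic app version.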
1162pub struct AppVersionHeader(SemanticVersion);
1163impl Header for AppVersionHeader {
1164 fn name() -> &'static HeaderName {
1165 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1166 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1167 }
1168
1169 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1170 where
1171 Self: Sized,
1172 I: Iterator<Item = &'i axum::http::HeaderValue>,
1173 {
1174 let version = values
1175 .next()
1176 .ok_or_else(axum::headers::Error::invalid)?
1177 .to_str()
1178 .map_err(|_| axum::headers::Error::invalid())?
1179 .parse()
1180 .map_err(|_| axum::headers::Error::invalid())?;
1181 Ok(Self(version))
1182 }
1183
1184 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1185 values.extend([self.0.to_string().parse().unwrap()]);
1186 }
1187}
1188
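/// The `x-zed-release-channel` request header, carrying the client's release channel.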
1189#[derive(Debug)]
1190pub struct ReleaseChannelHeader(String);
1191
1192impl Header for ReleaseChannelHeader {
1193 fn name() -> &'static HeaderName {
1194 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1195 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1196 }
1197
1198 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1199 where
1200 Self: Sized,
1201 I: Iterator<Item = &'i axum::http::HeaderValue>,
1202 {
1203 Ok(Self(
1204 values
1205 .next()
1206 .ok_or_else(axum::headers::Error::invalid)?
1207 .to_str()
1208 .map_err(|_| axum::headers::Error::invalid())?
1209 .to_owned(),
1210 ))
1211 }
1212
1213 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1214 values.extend([self.0.parse().unwrap()]);
1215 }
1216}
1217
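/// Builds the collab RPC router: `/rpc` performs the authenticated WebSocket
/// upgrade, and `/metrics` exposes Prometheus metrics.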
1218pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1219 Router::new()
1220 .route("/rpc", get(handle_websocket_request))
1221 .layer(
1222 ServiceBuilder::new()
1223 .layer(Extension(server.app_state.clone()))
1224 .layer(middleware::from_fn(auth::validate_header)),
1225 )
1226 .route("/metrics", get(handle_metrics))
1227 .layer(Extension(server))
1228}
1229
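/// Handles the `/rpc` WebSocket upgrade: rejects clients with a mismatched
/// protocol version, a missing app version header, or a version that can no
/// longer collaborate (HTTP 426), enforces the concurrent-connection limit,
/// and then hands the upgraded socket to `Server::handle_connection`.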
1230pub async fn handle_websocket_request(
1231 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1232 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1233 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1234 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1235 Extension(server): Extension<Arc<Server>>,
1236 Extension(principal): Extension<Principal>,
1237 user_agent: Option<TypedHeader<UserAgent>>,
1238 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1239 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1240 ws: WebSocketUpgrade,
1241) -> axum::response::Response {
1242 if protocol_version != rpc::PROTOCOL_VERSION {
1243 return (
1244 StatusCode::UPGRADE_REQUIRED,
1245 "client must be upgraded".to_string(),
1246 )
1247 .into_response();
1248 }
1249
1250 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1251 return (
1252 StatusCode::UPGRADE_REQUIRED,
1253 "no version header found".to_string(),
1254 )
1255 .into_response();
1256 };
1257
1258 let release_channel = release_channel_header.map(|header| header.0.0);
1259
1260 if !version.can_collaborate() {
1261 return (
1262 StatusCode::UPGRADE_REQUIRED,
1263 "client must be upgraded".to_string(),
1264 )
1265 .into_response();
1266 }
1267
1268 let socket_address = socket_address.to_string();
1269
    // Acquire the connection guard before the WebSocket upgrade so we can reject
    // the request early when the server is at capacity.
1271 let connection_guard = match ConnectionGuard::try_acquire() {
1272 Ok(guard) => guard,
1273 Err(()) => {
1274 return (
1275 StatusCode::SERVICE_UNAVAILABLE,
1276 "Too many concurrent connections",
1277 )
1278 .into_response();
1279 }
1280 };
1281
1282 ws.on_upgrade(move |socket| {
1283 let socket = socket
1284 .map_ok(to_tungstenite_message)
1285 .err_into()
1286 .with(|message| async move { to_axum_message(message) });
1287 let connection = Connection::new(Box::pin(socket));
1288 async move {
1289 server
1290 .handle_connection(
1291 connection,
1292 socket_address,
1293 principal,
1294 version,
1295 release_channel,
1296 user_agent.map(|header| header.to_string()),
1297 country_code_header.map(|header| header.to_string()),
1298 system_id_header.map(|header| header.to_string()),
1299 None,
1300 Executor::Production,
1301 Some(connection_guard),
1302 )
1303 .await;
1304 }
1305 })
1306}
1307
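/// Serves Prometheus metrics, including the number of non-admin connections and
/// the number of projects currently shared with guests.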
1308pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1309 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1310 let connections_metric = CONNECTIONS_METRIC
1311 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1312
1313 let connections = server
1314 .connection_pool
1315 .lock()
1316 .connections()
1317 .filter(|connection| !connection.admin)
1318 .count();
1319 connections_metric.set(connections as _);
1320
1321 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1322 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1323 register_int_gauge!(
1324 "shared_projects",
1325 "number of open projects with one or more guests"
1326 )
1327 .unwrap()
1328 });
1329
1330 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1331 shared_projects_metric.set(shared_projects as _);
1332
1333 let encoder = prometheus::TextEncoder::new();
1334 let metric_families = prometheus::gather();
1335 let encoded_metrics = encoder
1336 .encode_to_string(&metric_families)
1337 .map_err(|err| anyhow!("{err}"))?;
1338 Ok(encoded_metrics)
1339}
1340
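/// Cleans up after a dropped connection: removes it from the pool and database,
/// then waits up to `RECONNECT_TIMEOUT` for the client to reconnect before
/// releasing its room and channel-buffer state.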
1341#[instrument(err, skip(executor))]
1342async fn connection_lost(
1343 session: Session,
1344 mut teardown: watch::Receiver<bool>,
1345 executor: Executor,
1346) -> Result<()> {
1347 session.peer.disconnect(session.connection_id);
1348 session
1349 .connection_pool()
1350 .await
1351 .remove_connection(session.connection_id)?;
1352
1353 session
1354 .db()
1355 .await
1356 .connection_lost(session.connection_id)
1357 .await
1358 .trace_err();
1359
1360 futures::select_biased! {
1361 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1362
1363 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1364 leave_room_for_session(&session, session.connection_id).await.trace_err();
1365 leave_channel_buffers_for_session(&session)
1366 .await
1367 .trace_err();
1368
1369 if !session
1370 .connection_pool()
1371 .await
1372 .is_user_online(session.user_id())
1373 {
1374 let db = session.db().await;
1375 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1376 room_updated(&room, &session.peer);
1377 }
1378 }
1379
1380 update_user_contacts(session.user_id(), &session).await?;
1381 },
1382 _ = teardown.changed().fuse() => {}
1383 }
1384
1385 Ok(())
1386}
1387
1388/// Acknowledges a ping from a client, used to keep the connection alive.
1389async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1390 response.send(proto::Ack {})?;
1391 Ok(())
1392}
1393
1394/// Creates a new room for calling (outside of channels)
1395async fn create_room(
1396 _request: proto::CreateRoom,
1397 response: Response<proto::CreateRoom>,
1398 session: Session,
1399) -> Result<()> {
1400 let livekit_room = nanoid::nanoid!(30);
1401
1402 let live_kit_connection_info = util::maybe!(async {
1403 let live_kit = session.app_state.livekit_client.as_ref();
1404 let live_kit = live_kit?;
1405 let user_id = session.user_id().to_string();
1406
1407 let token = live_kit
1408 .room_token(&livekit_room, &user_id.to_string())
1409 .trace_err()?;
1410
1411 Some(proto::LiveKitConnectionInfo {
1412 server_url: live_kit.url().into(),
1413 token,
1414 can_publish: true,
1415 })
1416 })
1417 .await;
1418
1419 let room = session
1420 .db()
1421 .await
1422 .create_room(session.user_id(), session.connection_id, &livekit_room)
1423 .await?;
1424
1425 response.send(proto::CreateRoomResponse {
1426 room: Some(room.clone()),
1427 live_kit_connection_info,
1428 })?;
1429
1430 update_user_contacts(session.user_id(), &session).await?;
1431 Ok(())
1432}
1433
/// Joins a room from an invitation. If the room belongs to a channel, this is
/// equivalent to joining that channel.
1435async fn join_room(
1436 request: proto::JoinRoom,
1437 response: Response<proto::JoinRoom>,
1438 session: Session,
1439) -> Result<()> {
1440 let room_id = RoomId::from_proto(request.id);
1441
1442 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1443
1444 if let Some(channel_id) = channel_id {
1445 return join_channel_internal(channel_id, Box::new(response), session).await;
1446 }
1447
1448 let joined_room = {
1449 let room = session
1450 .db()
1451 .await
1452 .join_room(room_id, session.user_id(), session.connection_id)
1453 .await?;
1454 room_updated(&room.room, &session.peer);
1455 room.into_inner()
1456 };
1457
1458 for connection_id in session
1459 .connection_pool()
1460 .await
1461 .user_connection_ids(session.user_id())
1462 {
1463 session
1464 .peer
1465 .send(
1466 connection_id,
1467 proto::CallCanceled {
1468 room_id: room_id.to_proto(),
1469 },
1470 )
1471 .trace_err();
1472 }
1473
1474 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1475 {
1476 live_kit
1477 .room_token(
1478 &joined_room.room.livekit_room,
1479 &session.user_id().to_string(),
1480 )
1481 .trace_err()
1482 .map(|token| proto::LiveKitConnectionInfo {
1483 server_url: live_kit.url().into(),
1484 token,
1485 can_publish: true,
1486 })
1487 } else {
1488 None
1489 };
1490
1491 response.send(proto::JoinRoomResponse {
1492 room: Some(joined_room.room),
1493 channel_id: None,
1494 live_kit_connection_info,
1495 })?;
1496
1497 update_user_contacts(session.user_id(), &session).await?;
1498 Ok(())
1499}
1500
/// Rejoins a room, reconnecting to it after a connection error.
1502async fn rejoin_room(
1503 request: proto::RejoinRoom,
1504 response: Response<proto::RejoinRoom>,
1505 session: Session,
1506) -> Result<()> {
1507 let room;
1508 let channel;
1509 {
1510 let mut rejoined_room = session
1511 .db()
1512 .await
1513 .rejoin_room(request, session.user_id(), session.connection_id)
1514 .await?;
1515
1516 response.send(proto::RejoinRoomResponse {
1517 room: Some(rejoined_room.room.clone()),
1518 reshared_projects: rejoined_room
1519 .reshared_projects
1520 .iter()
1521 .map(|project| proto::ResharedProject {
1522 id: project.id.to_proto(),
1523 collaborators: project
1524 .collaborators
1525 .iter()
1526 .map(|collaborator| collaborator.to_proto())
1527 .collect(),
1528 })
1529 .collect(),
1530 rejoined_projects: rejoined_room
1531 .rejoined_projects
1532 .iter()
1533 .map(|rejoined_project| rejoined_project.to_proto())
1534 .collect(),
1535 })?;
1536 room_updated(&rejoined_room.room, &session.peer);
1537
1538 for project in &rejoined_room.reshared_projects {
1539 for collaborator in &project.collaborators {
1540 session
1541 .peer
1542 .send(
1543 collaborator.connection_id,
1544 proto::UpdateProjectCollaborator {
1545 project_id: project.id.to_proto(),
1546 old_peer_id: Some(project.old_connection_id.into()),
1547 new_peer_id: Some(session.connection_id.into()),
1548 },
1549 )
1550 .trace_err();
1551 }
1552
1553 broadcast(
1554 Some(session.connection_id),
1555 project
1556 .collaborators
1557 .iter()
1558 .map(|collaborator| collaborator.connection_id),
1559 |connection_id| {
1560 session.peer.forward_send(
1561 session.connection_id,
1562 connection_id,
1563 proto::UpdateProject {
1564 project_id: project.id.to_proto(),
1565 worktrees: project.worktrees.clone(),
1566 },
1567 )
1568 },
1569 );
1570 }
1571
1572 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1573
1574 let rejoined_room = rejoined_room.into_inner();
1575
1576 room = rejoined_room.room;
1577 channel = rejoined_room.channel;
1578 }
1579
1580 if let Some(channel) = channel {
1581 channel_updated(
1582 &channel,
1583 &room,
1584 &session.peer,
1585 &*session.connection_pool().await,
1586 );
1587 }
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
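/// Notifies existing collaborators of the rejoining session's new peer id and
/// re-streams worktree entries, diagnostics, settings files, and repository
/// state to the rejoining client.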
1593fn notify_rejoined_projects(
1594 rejoined_projects: &mut Vec<RejoinedProject>,
1595 session: &Session,
1596) -> Result<()> {
1597 for project in rejoined_projects.iter() {
1598 for collaborator in &project.collaborators {
1599 session
1600 .peer
1601 .send(
1602 collaborator.connection_id,
1603 proto::UpdateProjectCollaborator {
1604 project_id: project.id.to_proto(),
1605 old_peer_id: Some(project.old_connection_id.into()),
1606 new_peer_id: Some(session.connection_id.into()),
1607 },
1608 )
1609 .trace_err();
1610 }
1611 }
1612
1613 for project in rejoined_projects {
1614 for worktree in mem::take(&mut project.worktrees) {
1615 // Stream this worktree's entries.
1616 let message = proto::UpdateWorktree {
1617 project_id: project.id.to_proto(),
1618 worktree_id: worktree.id,
1619 abs_path: worktree.abs_path.clone(),
1620 root_name: worktree.root_name,
1621 updated_entries: worktree.updated_entries,
1622 removed_entries: worktree.removed_entries,
1623 scan_id: worktree.scan_id,
1624 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1625 updated_repositories: worktree.updated_repositories,
1626 removed_repositories: worktree.removed_repositories,
1627 };
1628 for update in proto::split_worktree_update(message) {
1629 session.peer.send(session.connection_id, update)?;
1630 }
1631
1632 // Stream this worktree's diagnostics.
1633 for summary in worktree.diagnostic_summaries {
1634 session.peer.send(
1635 session.connection_id,
1636 proto::UpdateDiagnosticSummary {
1637 project_id: project.id.to_proto(),
1638 worktree_id: worktree.id,
1639 summary: Some(summary),
1640 },
1641 )?;
1642 }
1643
1644 for settings_file in worktree.settings_files {
1645 session.peer.send(
1646 session.connection_id,
1647 proto::UpdateWorktreeSettings {
1648 project_id: project.id.to_proto(),
1649 worktree_id: worktree.id,
1650 path: settings_file.path,
1651 content: Some(settings_file.content),
1652 kind: Some(settings_file.kind.to_proto().into()),
1653 },
1654 )?;
1655 }
1656 }
1657
1658 for repository in mem::take(&mut project.updated_repositories) {
1659 for update in split_repository_update(repository) {
1660 session.peer.send(session.connection_id, update)?;
1661 }
1662 }
1663
1664 for id in mem::take(&mut project.removed_repositories) {
1665 session.peer.send(
1666 session.connection_id,
1667 proto::RemoveRepository {
1668 project_id: project.id.to_proto(),
1669 id,
1670 },
1671 )?;
1672 }
1673 }
1674
1675 Ok(())
1676}
1677
/// Leaves the current room, disconnecting the caller from it.
1679async fn leave_room(
1680 _: proto::LeaveRoom,
1681 response: Response<proto::LeaveRoom>,
1682 session: Session,
1683) -> Result<()> {
1684 leave_room_for_session(&session, session.connection_id).await?;
1685 response.send(proto::Ack {})?;
1686 Ok(())
1687}
1688
/// Updates the role (and therefore the permissions) of another participant in the room.
1690async fn set_room_participant_role(
1691 request: proto::SetRoomParticipantRole,
1692 response: Response<proto::SetRoomParticipantRole>,
1693 session: Session,
1694) -> Result<()> {
1695 let user_id = UserId::from_proto(request.user_id);
1696 let role = ChannelRole::from(request.role());
1697
1698 let (livekit_room, can_publish) = {
1699 let room = session
1700 .db()
1701 .await
1702 .set_room_participant_role(
1703 session.user_id(),
1704 RoomId::from_proto(request.room_id),
1705 user_id,
1706 role,
1707 )
1708 .await?;
1709
1710 let livekit_room = room.livekit_room.clone();
1711 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1712 room_updated(&room, &session.peer);
1713 (livekit_room, can_publish)
1714 };
1715
1716 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1717 live_kit
1718 .update_participant(
1719 livekit_room.clone(),
1720 request.user_id.to_string(),
1721 livekit_api::proto::ParticipantPermission {
1722 can_subscribe: true,
1723 can_publish,
1724 can_publish_data: can_publish,
1725 hidden: false,
1726 recorder: false,
1727 },
1728 )
1729 .await
1730 .trace_err();
1731 }
1732
1733 response.send(proto::Ack {})?;
1734 Ok(())
1735}
1736
/// Calls another user into the current room.
1738async fn call(
1739 request: proto::Call,
1740 response: Response<proto::Call>,
1741 session: Session,
1742) -> Result<()> {
1743 let room_id = RoomId::from_proto(request.room_id);
1744 let calling_user_id = session.user_id();
1745 let calling_connection_id = session.connection_id;
1746 let called_user_id = UserId::from_proto(request.called_user_id);
1747 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1748 if !session
1749 .db()
1750 .await
1751 .has_contact(calling_user_id, called_user_id)
1752 .await?
1753 {
1754 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1755 }
1756
1757 let incoming_call = {
1758 let (room, incoming_call) = &mut *session
1759 .db()
1760 .await
1761 .call(
1762 room_id,
1763 calling_user_id,
1764 calling_connection_id,
1765 called_user_id,
1766 initial_project_id,
1767 )
1768 .await?;
1769 room_updated(room, &session.peer);
1770 mem::take(incoming_call)
1771 };
1772 update_user_contacts(called_user_id, &session).await?;
1773
1774 let mut calls = session
1775 .connection_pool()
1776 .await
1777 .user_connection_ids(called_user_id)
1778 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1779 .collect::<FuturesUnordered<_>>();
1780
1781 while let Some(call_response) = calls.next().await {
1782 match call_response.as_ref() {
1783 Ok(_) => {
1784 response.send(proto::Ack {})?;
1785 return Ok(());
1786 }
1787 Err(_) => {
1788 call_response.trace_err();
1789 }
1790 }
1791 }
1792
1793 {
1794 let room = session
1795 .db()
1796 .await
1797 .call_failed(room_id, called_user_id)
1798 .await?;
1799 room_updated(&room, &session.peer);
1800 }
1801 update_user_contacts(called_user_id, &session).await?;
1802
1803 Err(anyhow!("failed to ring user"))?
1804}
1805
1806/// Cancel an outgoing call.
1807async fn cancel_call(
1808 request: proto::CancelCall,
1809 response: Response<proto::CancelCall>,
1810 session: Session,
1811) -> Result<()> {
1812 let called_user_id = UserId::from_proto(request.called_user_id);
1813 let room_id = RoomId::from_proto(request.room_id);
1814 {
1815 let room = session
1816 .db()
1817 .await
1818 .cancel_call(room_id, session.connection_id, called_user_id)
1819 .await?;
1820 room_updated(&room, &session.peer);
1821 }
1822
1823 for connection_id in session
1824 .connection_pool()
1825 .await
1826 .user_connection_ids(called_user_id)
1827 {
1828 session
1829 .peer
1830 .send(
1831 connection_id,
1832 proto::CallCanceled {
1833 room_id: room_id.to_proto(),
1834 },
1835 )
1836 .trace_err();
1837 }
1838 response.send(proto::Ack {})?;
1839
1840 update_user_contacts(called_user_id, &session).await?;
1841 Ok(())
1842}
1843
1844/// Decline an incoming call.
1845async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1846 let room_id = RoomId::from_proto(message.room_id);
1847 {
1848 let room = session
1849 .db()
1850 .await
1851 .decline_call(Some(room_id), session.user_id())
1852 .await?
1853 .context("declining call")?;
1854 room_updated(&room, &session.peer);
1855 }
1856
1857 for connection_id in session
1858 .connection_pool()
1859 .await
1860 .user_connection_ids(session.user_id())
1861 {
1862 session
1863 .peer
1864 .send(
1865 connection_id,
1866 proto::CallCanceled {
1867 room_id: room_id.to_proto(),
1868 },
1869 )
1870 .trace_err();
1871 }
1872 update_user_contacts(session.user_id(), &session).await?;
1873 Ok(())
1874}
1875
1876/// Updates other participants in the room with your current location.
1877async fn update_participant_location(
1878 request: proto::UpdateParticipantLocation,
1879 response: Response<proto::UpdateParticipantLocation>,
1880 session: Session,
1881) -> Result<()> {
1882 let room_id = RoomId::from_proto(request.room_id);
1883 let location = request.location.context("invalid location")?;
1884
1885 let db = session.db().await;
1886 let room = db
1887 .update_room_participant_location(room_id, session.connection_id, location)
1888 .await?;
1889
1890 room_updated(&room, &session.peer);
1891 response.send(proto::Ack {})?;
1892 Ok(())
1893}
1894
1895/// Share a project into the room.
1896async fn share_project(
1897 request: proto::ShareProject,
1898 response: Response<proto::ShareProject>,
1899 session: Session,
1900) -> Result<()> {
1901 let (project_id, room) = &*session
1902 .db()
1903 .await
1904 .share_project(
1905 RoomId::from_proto(request.room_id),
1906 session.connection_id,
1907 &request.worktrees,
1908 request.is_ssh_project,
1909 )
1910 .await?;
1911 response.send(proto::ShareProjectResponse {
1912 project_id: project_id.to_proto(),
1913 })?;
1914 room_updated(room, &session.peer);
1915
1916 Ok(())
1917}
1918
1919/// Unshare a project from the room.
1920async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1921 let project_id = ProjectId::from_proto(message.project_id);
1922 unshare_project_internal(project_id, session.connection_id, &session).await
1923}
1924
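/// Shared logic for unsharing a project: notifies the remaining guests, updates
/// the room if there is one, and deletes the project when the database marks it
/// for deletion.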
1925async fn unshare_project_internal(
1926 project_id: ProjectId,
1927 connection_id: ConnectionId,
1928 session: &Session,
1929) -> Result<()> {
1930 let delete = {
1931 let room_guard = session
1932 .db()
1933 .await
1934 .unshare_project(project_id, connection_id)
1935 .await?;
1936
1937 let (delete, room, guest_connection_ids) = &*room_guard;
1938
1939 let message = proto::UnshareProject {
1940 project_id: project_id.to_proto(),
1941 };
1942
1943 broadcast(
1944 Some(connection_id),
1945 guest_connection_ids.iter().copied(),
1946 |conn_id| session.peer.send(conn_id, message.clone()),
1947 );
1948 if let Some(room) = room {
1949 room_updated(room, &session.peer);
1950 }
1951
1952 *delete
1953 };
1954
1955 if delete {
1956 let db = session.db().await;
1957 db.delete_project(project_id).await?;
1958 }
1959
1960 Ok(())
1961}
1962
1963/// Join someone else's shared project.
1964async fn join_project(
1965 request: proto::JoinProject,
1966 response: Response<proto::JoinProject>,
1967 session: Session,
1968) -> Result<()> {
1969 let project_id = ProjectId::from_proto(request.project_id);
1970
1971 tracing::info!(%project_id, "join project");
1972
1973 let db = session.db().await;
1974 let (project, replica_id) = &mut *db
1975 .join_project(
1976 project_id,
1977 session.connection_id,
1978 session.user_id(),
1979 request.committer_name.clone(),
1980 request.committer_email.clone(),
1981 )
1982 .await?;
1983 drop(db);
1984 tracing::info!(%project_id, "join remote project");
1985 let collaborators = project
1986 .collaborators
1987 .iter()
1988 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1989 .map(|collaborator| collaborator.to_proto())
1990 .collect::<Vec<_>>();
1991 let project_id = project.id;
1992 let guest_user_id = session.user_id();
1993
1994 let worktrees = project
1995 .worktrees
1996 .iter()
1997 .map(|(id, worktree)| proto::WorktreeMetadata {
1998 id: *id,
1999 root_name: worktree.root_name.clone(),
2000 visible: worktree.visible,
2001 abs_path: worktree.abs_path.clone(),
2002 })
2003 .collect::<Vec<_>>();
2004
2005 let add_project_collaborator = proto::AddProjectCollaborator {
2006 project_id: project_id.to_proto(),
2007 collaborator: Some(proto::Collaborator {
2008 peer_id: Some(session.connection_id.into()),
2009 replica_id: replica_id.0 as u32,
2010 user_id: guest_user_id.to_proto(),
2011 is_host: false,
2012 committer_name: request.committer_name.clone(),
2013 committer_email: request.committer_email.clone(),
2014 }),
2015 };
2016
2017 for collaborator in &collaborators {
2018 session
2019 .peer
2020 .send(
2021 collaborator.peer_id.unwrap().into(),
2022 add_project_collaborator.clone(),
2023 )
2024 .trace_err();
2025 }
2026
2027 // First, we send the metadata associated with each worktree.
2028 let (language_servers, language_server_capabilities) = project
2029 .language_servers
2030 .clone()
2031 .into_iter()
2032 .map(|server| (server.server, server.capabilities))
2033 .unzip();
2034 response.send(proto::JoinProjectResponse {
2035 project_id: project.id.0 as u64,
2036 worktrees: worktrees.clone(),
2037 replica_id: replica_id.0 as u32,
2038 collaborators: collaborators.clone(),
2039 language_servers,
2040 language_server_capabilities,
2041 role: project.role.into(),
2042 })?;
2043
2044 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2045 // Stream this worktree's entries.
2046 let message = proto::UpdateWorktree {
2047 project_id: project_id.to_proto(),
2048 worktree_id,
2049 abs_path: worktree.abs_path.clone(),
2050 root_name: worktree.root_name,
2051 updated_entries: worktree.entries,
2052 removed_entries: Default::default(),
2053 scan_id: worktree.scan_id,
2054 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2055 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2056 removed_repositories: Default::default(),
2057 };
2058 for update in proto::split_worktree_update(message) {
2059 session.peer.send(session.connection_id, update.clone())?;
2060 }
2061
2062 // Stream this worktree's diagnostics.
2063 for summary in worktree.diagnostic_summaries {
2064 session.peer.send(
2065 session.connection_id,
2066 proto::UpdateDiagnosticSummary {
2067 project_id: project_id.to_proto(),
2068 worktree_id: worktree.id,
2069 summary: Some(summary),
2070 },
2071 )?;
2072 }
2073
2074 for settings_file in worktree.settings_files {
2075 session.peer.send(
2076 session.connection_id,
2077 proto::UpdateWorktreeSettings {
2078 project_id: project_id.to_proto(),
2079 worktree_id: worktree.id,
2080 path: settings_file.path,
2081 content: Some(settings_file.content),
2082 kind: Some(settings_file.kind.to_proto() as i32),
2083 },
2084 )?;
2085 }
2086 }
2087
2088 for repository in mem::take(&mut project.repositories) {
2089 for update in split_repository_update(repository) {
2090 session.peer.send(session.connection_id, update)?;
2091 }
2092 }
2093
2094 for language_server in &project.language_servers {
2095 session.peer.send(
2096 session.connection_id,
2097 proto::UpdateLanguageServer {
2098 project_id: project_id.to_proto(),
2099 server_name: Some(language_server.server.name.clone()),
2100 language_server_id: language_server.server.id,
2101 variant: Some(
2102 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2103 proto::LspDiskBasedDiagnosticsUpdated {},
2104 ),
2105 ),
2106 },
2107 )?;
2108 }
2109
2110 Ok(())
2111}
2112
2113/// Leave someone else's shared project.
2114async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2115 let sender_id = session.connection_id;
2116 let project_id = ProjectId::from_proto(request.project_id);
2117 let db = session.db().await;
2118
2119 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2120 tracing::info!(
2121 %project_id,
2122 "leave project"
2123 );
2124
2125 project_left(project, &session);
2126 if let Some(room) = room {
2127 room_updated(room, &session.peer);
2128 }
2129
2130 Ok(())
2131}
2132
2133/// Updates other participants with changes to the project
2134async fn update_project(
2135 request: proto::UpdateProject,
2136 response: Response<proto::UpdateProject>,
2137 session: Session,
2138) -> Result<()> {
2139 let project_id = ProjectId::from_proto(request.project_id);
2140 let (room, guest_connection_ids) = &*session
2141 .db()
2142 .await
2143 .update_project(project_id, session.connection_id, &request.worktrees)
2144 .await?;
2145 broadcast(
2146 Some(session.connection_id),
2147 guest_connection_ids.iter().copied(),
2148 |connection_id| {
2149 session
2150 .peer
2151 .forward_send(session.connection_id, connection_id, request.clone())
2152 },
2153 );
2154 if let Some(room) = room {
2155 room_updated(room, &session.peer);
2156 }
2157 response.send(proto::Ack {})?;
2158
2159 Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree
2163async fn update_worktree(
2164 request: proto::UpdateWorktree,
2165 response: Response<proto::UpdateWorktree>,
2166 session: Session,
2167) -> Result<()> {
2168 let guest_connection_ids = session
2169 .db()
2170 .await
2171 .update_worktree(&request, session.connection_id)
2172 .await?;
2173
2174 broadcast(
2175 Some(session.connection_id),
2176 guest_connection_ids.iter().copied(),
2177 |connection_id| {
2178 session
2179 .peer
2180 .forward_send(session.connection_id, connection_id, request.clone())
2181 },
2182 );
2183 response.send(proto::Ack {})?;
2184 Ok(())
2185}
2186
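/// Updates other participants with changes to a repository.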
2187async fn update_repository(
2188 request: proto::UpdateRepository,
2189 response: Response<proto::UpdateRepository>,
2190 session: Session,
2191) -> Result<()> {
2192 let guest_connection_ids = session
2193 .db()
2194 .await
2195 .update_repository(&request, session.connection_id)
2196 .await?;
2197
2198 broadcast(
2199 Some(session.connection_id),
2200 guest_connection_ids.iter().copied(),
2201 |connection_id| {
2202 session
2203 .peer
2204 .forward_send(session.connection_id, connection_id, request.clone())
2205 },
2206 );
2207 response.send(proto::Ack {})?;
2208 Ok(())
2209}
2210
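/// Notifies other participants that a repository has been removed.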
2211async fn remove_repository(
2212 request: proto::RemoveRepository,
2213 response: Response<proto::RemoveRepository>,
2214 session: Session,
2215) -> Result<()> {
2216 let guest_connection_ids = session
2217 .db()
2218 .await
2219 .remove_repository(&request, session.connection_id)
2220 .await?;
2221
2222 broadcast(
2223 Some(session.connection_id),
2224 guest_connection_ids.iter().copied(),
2225 |connection_id| {
2226 session
2227 .peer
2228 .forward_send(session.connection_id, connection_id, request.clone())
2229 },
2230 );
2231 response.send(proto::Ack {})?;
2232 Ok(())
2233}
2234
2235/// Updates other participants with changes to the diagnostics
2236async fn update_diagnostic_summary(
2237 message: proto::UpdateDiagnosticSummary,
2238 session: Session,
2239) -> Result<()> {
2240 let guest_connection_ids = session
2241 .db()
2242 .await
2243 .update_diagnostic_summary(&message, session.connection_id)
2244 .await?;
2245
2246 broadcast(
2247 Some(session.connection_id),
2248 guest_connection_ids.iter().copied(),
2249 |connection_id| {
2250 session
2251 .peer
2252 .forward_send(session.connection_id, connection_id, message.clone())
2253 },
2254 );
2255
2256 Ok(())
2257}
2258
2259/// Updates other participants with changes to the worktree settings
2260async fn update_worktree_settings(
2261 message: proto::UpdateWorktreeSettings,
2262 session: Session,
2263) -> Result<()> {
2264 let guest_connection_ids = session
2265 .db()
2266 .await
2267 .update_worktree_settings(&message, session.connection_id)
2268 .await?;
2269
2270 broadcast(
2271 Some(session.connection_id),
2272 guest_connection_ids.iter().copied(),
2273 |connection_id| {
2274 session
2275 .peer
2276 .forward_send(session.connection_id, connection_id, message.clone())
2277 },
2278 );
2279
2280 Ok(())
2281}
2282
2283/// Notify other participants that a language server has started.
2284async fn start_language_server(
2285 request: proto::StartLanguageServer,
2286 session: Session,
2287) -> Result<()> {
2288 let guest_connection_ids = session
2289 .db()
2290 .await
2291 .start_language_server(&request, session.connection_id)
2292 .await?;
2293
2294 broadcast(
2295 Some(session.connection_id),
2296 guest_connection_ids.iter().copied(),
2297 |connection_id| {
2298 session
2299 .peer
2300 .forward_send(session.connection_id, connection_id, request.clone())
2301 },
2302 );
2303 Ok(())
2304}
2305
2306/// Notify other participants that a language server has changed.
2307async fn update_language_server(
2308 request: proto::UpdateLanguageServer,
2309 session: Session,
2310) -> Result<()> {
2311 let project_id = ProjectId::from_proto(request.project_id);
2312 let db = session.db().await;
2313
2314 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2315 {
2316 if let Some(capabilities) = update.capabilities.clone() {
2317 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2318 .await?;
2319 }
2320 }
2321
2322 let project_connection_ids = db
2323 .project_connection_ids(project_id, session.connection_id, true)
2324 .await?;
2325 broadcast(
2326 Some(session.connection_id),
2327 project_connection_ids.iter().copied(),
2328 |connection_id| {
2329 session
2330 .peer
2331 .forward_send(session.connection_id, connection_id, request.clone())
2332 },
2333 );
2334 Ok(())
2335}
2336
2337/// Forward a project request to the host. These requests should be read-only,
2338/// as guests are allowed to send them.
2339async fn forward_read_only_project_request<T>(
2340 request: T,
2341 response: Response<T>,
2342 session: Session,
2343) -> Result<()>
2344where
2345 T: EntityMessage + RequestMessage,
2346{
2347 let project_id = ProjectId::from_proto(request.remote_entity_id());
2348 let host_connection_id = session
2349 .db()
2350 .await
2351 .host_for_read_only_project_request(project_id, session.connection_id)
2352 .await?;
2353 let payload = session
2354 .peer
2355 .forward_request(session.connection_id, host_connection_id, request)
2356 .await?;
2357 response.send(payload)?;
2358 Ok(())
2359}
2360
2361/// Forward a project request to the host. These requests are disallowed
2362/// for guests.
2363async fn forward_mutating_project_request<T>(
2364 request: T,
2365 response: Response<T>,
2366 session: Session,
2367) -> Result<()>
2368where
2369 T: EntityMessage + RequestMessage,
2370{
2371 let project_id = ProjectId::from_proto(request.remote_entity_id());
2372
2373 let host_connection_id = session
2374 .db()
2375 .await
2376 .host_for_mutating_project_request(project_id, session.connection_id)
2377 .await?;
2378 let payload = session
2379 .peer
2380 .forward_request(session.connection_id, host_connection_id, request)
2381 .await?;
2382 response.send(payload)?;
2383 Ok(())
2384}
2385
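/// Records the query kind for tracing, then forwards the request to the project
/// host like any other mutating project request.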
2386async fn multi_lsp_query(
2387 request: MultiLspQuery,
2388 response: Response<MultiLspQuery>,
2389 session: Session,
2390) -> Result<()> {
2391 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2392 tracing::info!("multi_lsp_query message received");
2393 forward_mutating_project_request(request, response, session).await
2394}
2395
2396/// Notify other participants that a new buffer has been created
2397async fn create_buffer_for_peer(
2398 request: proto::CreateBufferForPeer,
2399 session: Session,
2400) -> Result<()> {
2401 session
2402 .db()
2403 .await
2404 .check_user_is_project_host(
2405 ProjectId::from_proto(request.project_id),
2406 session.connection_id,
2407 )
2408 .await?;
2409 let peer_id = request.peer_id.context("invalid peer id")?;
2410 session
2411 .peer
2412 .forward_send(session.connection_id, peer_id.into(), request)?;
2413 Ok(())
2414}
2415
2416/// Notify other participants that a buffer has been updated. This is
2417/// allowed for guests as long as the update is limited to selections.
2418async fn update_buffer(
2419 request: proto::UpdateBuffer,
2420 response: Response<proto::UpdateBuffer>,
2421 session: Session,
2422) -> Result<()> {
2423 let project_id = ProjectId::from_proto(request.project_id);
2424 let mut capability = Capability::ReadOnly;
2425
2426 for op in request.operations.iter() {
2427 match op.variant {
2428 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2429 Some(_) => capability = Capability::ReadWrite,
2430 }
2431 }
2432
2433 let host = {
2434 let guard = session
2435 .db()
2436 .await
2437 .connections_for_buffer_update(project_id, session.connection_id, capability)
2438 .await?;
2439
2440 let (host, guests) = &*guard;
2441
2442 broadcast(
2443 Some(session.connection_id),
2444 guests.clone(),
2445 |connection_id| {
2446 session
2447 .peer
2448 .forward_send(session.connection_id, connection_id, request.clone())
2449 },
2450 );
2451
2452 *host
2453 };
2454
2455 if host != session.connection_id {
2456 session
2457 .peer
2458 .forward_request(session.connection_id, host, request.clone())
2459 .await?;
2460 }
2461
2462 response.send(proto::Ack {})?;
2463 Ok(())
2464}
2465
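/// Broadcasts a context update to the project host and guests. Write access is
/// required unless the operation only updates selections.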
2466async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2467 let project_id = ProjectId::from_proto(message.project_id);
2468
2469 let operation = message.operation.as_ref().context("invalid operation")?;
2470 let capability = match operation.variant.as_ref() {
2471 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2472 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2473 match buffer_op.variant {
2474 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2475 Capability::ReadOnly
2476 }
2477 _ => Capability::ReadWrite,
2478 }
2479 } else {
2480 Capability::ReadWrite
2481 }
2482 }
2483 Some(_) => Capability::ReadWrite,
2484 None => Capability::ReadOnly,
2485 };
2486
2487 let guard = session
2488 .db()
2489 .await
2490 .connections_for_buffer_update(project_id, session.connection_id, capability)
2491 .await?;
2492
2493 let (host, guests) = &*guard;
2494
2495 broadcast(
2496 Some(session.connection_id),
2497 guests.iter().chain([host]).copied(),
2498 |connection_id| {
2499 session
2500 .peer
2501 .forward_send(session.connection_id, connection_id, message.clone())
2502 },
2503 );
2504
2505 Ok(())
2506}
2507
2508/// Notify other participants that a project has been updated.
2509async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2510 request: T,
2511 session: Session,
2512) -> Result<()> {
2513 let project_id = ProjectId::from_proto(request.remote_entity_id());
2514 let project_connection_ids = session
2515 .db()
2516 .await
2517 .project_connection_ids(project_id, session.connection_id, false)
2518 .await?;
2519
2520 broadcast(
2521 Some(session.connection_id),
2522 project_connection_ids.iter().copied(),
2523 |connection_id| {
2524 session
2525 .peer
2526 .forward_send(session.connection_id, connection_id, request.clone())
2527 },
2528 );
2529 Ok(())
2530}
2531
2532/// Start following another user in a call.
2533async fn follow(
2534 request: proto::Follow,
2535 response: Response<proto::Follow>,
2536 session: Session,
2537) -> Result<()> {
2538 let room_id = RoomId::from_proto(request.room_id);
2539 let project_id = request.project_id.map(ProjectId::from_proto);
2540 let leader_id = request.leader_id.context("invalid leader id")?.into();
2541 let follower_id = session.connection_id;
2542
2543 session
2544 .db()
2545 .await
2546 .check_room_participants(room_id, leader_id, session.connection_id)
2547 .await?;
2548
2549 let response_payload = session
2550 .peer
2551 .forward_request(session.connection_id, leader_id, request)
2552 .await?;
2553 response.send(response_payload)?;
2554
2555 if let Some(project_id) = project_id {
2556 let room = session
2557 .db()
2558 .await
2559 .follow(room_id, project_id, leader_id, follower_id)
2560 .await?;
2561 room_updated(&room, &session.peer);
2562 }
2563
2564 Ok(())
2565}
2566
2567/// Stop following another user in a call.
2568async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2569 let room_id = RoomId::from_proto(request.room_id);
2570 let project_id = request.project_id.map(ProjectId::from_proto);
2571 let leader_id = request.leader_id.context("invalid leader id")?.into();
2572 let follower_id = session.connection_id;
2573
2574 session
2575 .db()
2576 .await
2577 .check_room_participants(room_id, leader_id, session.connection_id)
2578 .await?;
2579
2580 session
2581 .peer
2582 .forward_send(session.connection_id, leader_id, request)?;
2583
2584 if let Some(project_id) = project_id {
2585 let room = session
2586 .db()
2587 .await
2588 .unfollow(room_id, project_id, leader_id, follower_id)
2589 .await?;
2590 room_updated(&room, &session.peer);
2591 }
2592
2593 Ok(())
2594}
2595
2596/// Notify everyone following you of your current location.
2597async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2598 let room_id = RoomId::from_proto(request.room_id);
2599 let database = session.db.lock().await;
2600
2601 let connection_ids = if let Some(project_id) = request.project_id {
2602 let project_id = ProjectId::from_proto(project_id);
2603 database
2604 .project_connection_ids(project_id, session.connection_id, true)
2605 .await?
2606 } else {
2607 database
2608 .room_connection_ids(room_id, session.connection_id)
2609 .await?
2610 };
2611
2612 // For now, don't send view update messages back to that view's current leader.
2613 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2614 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2615 _ => None,
2616 });
2617
2618 for connection_id in connection_ids.iter().cloned() {
2619 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2620 session
2621 .peer
2622 .forward_send(session.connection_id, connection_id, request.clone())?;
2623 }
2624 }
2625 Ok(())
2626}
2627
2628/// Get public data about users.
2629async fn get_users(
2630 request: proto::GetUsers,
2631 response: Response<proto::GetUsers>,
2632 session: Session,
2633) -> Result<()> {
2634 let user_ids = request
2635 .user_ids
2636 .into_iter()
2637 .map(UserId::from_proto)
2638 .collect();
2639 let users = session
2640 .db()
2641 .await
2642 .get_users_by_ids(user_ids)
2643 .await?
2644 .into_iter()
2645 .map(|user| proto::User {
2646 id: user.id.to_proto(),
2647 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2648 github_login: user.github_login,
2649 name: user.name,
2650 })
2651 .collect();
2652 response.send(proto::UsersResponse { users })?;
2653 Ok(())
2654}
2655
2656/// Search for users (to invite) by their GitHub login.
2657async fn fuzzy_search_users(
2658 request: proto::FuzzySearchUsers,
2659 response: Response<proto::FuzzySearchUsers>,
2660 session: Session,
2661) -> Result<()> {
2662 let query = request.query;
2663 let users = match query.len() {
2664 0 => vec![],
2665 1 | 2 => session
2666 .db()
2667 .await
2668 .get_user_by_github_login(&query)
2669 .await?
2670 .into_iter()
2671 .collect(),
2672 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2673 };
2674 let users = users
2675 .into_iter()
2676 .filter(|user| user.id != session.user_id())
2677 .map(|user| proto::User {
2678 id: user.id.to_proto(),
2679 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2680 github_login: user.github_login,
2681 name: user.name,
2682 })
2683 .collect();
2684 response.send(proto::UsersResponse { users })?;
2685 Ok(())
2686}
2687
2688/// Send a contact request to another user.
2689async fn request_contact(
2690 request: proto::RequestContact,
2691 response: Response<proto::RequestContact>,
2692 session: Session,
2693) -> Result<()> {
2694 let requester_id = session.user_id();
2695 let responder_id = UserId::from_proto(request.responder_id);
2696 if requester_id == responder_id {
2697 return Err(anyhow!("cannot add yourself as a contact"))?;
2698 }
2699
2700 let notifications = session
2701 .db()
2702 .await
2703 .send_contact_request(requester_id, responder_id)
2704 .await?;
2705
2706 // Update outgoing contact requests of requester
2707 let mut update = proto::UpdateContacts::default();
2708 update.outgoing_requests.push(responder_id.to_proto());
2709 for connection_id in session
2710 .connection_pool()
2711 .await
2712 .user_connection_ids(requester_id)
2713 {
2714 session.peer.send(connection_id, update.clone())?;
2715 }
2716
2717 // Update incoming contact requests of responder
2718 let mut update = proto::UpdateContacts::default();
2719 update
2720 .incoming_requests
2721 .push(proto::IncomingContactRequest {
2722 requester_id: requester_id.to_proto(),
2723 });
2724 let connection_pool = session.connection_pool().await;
2725 for connection_id in connection_pool.user_connection_ids(responder_id) {
2726 session.peer.send(connection_id, update.clone())?;
2727 }
2728
2729 send_notifications(&connection_pool, &session.peer, notifications);
2730
2731 response.send(proto::Ack {})?;
2732 Ok(())
2733}
2734
2735/// Accept or decline a contact request
2736async fn respond_to_contact_request(
2737 request: proto::RespondToContactRequest,
2738 response: Response<proto::RespondToContactRequest>,
2739 session: Session,
2740) -> Result<()> {
2741 let responder_id = session.user_id();
2742 let requester_id = UserId::from_proto(request.requester_id);
2743 let db = session.db().await;
2744 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2745 db.dismiss_contact_notification(responder_id, requester_id)
2746 .await?;
2747 } else {
2748 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2749
2750 let notifications = db
2751 .respond_to_contact_request(responder_id, requester_id, accept)
2752 .await?;
2753 let requester_busy = db.is_user_busy(requester_id).await?;
2754 let responder_busy = db.is_user_busy(responder_id).await?;
2755
2756 let pool = session.connection_pool().await;
2757 // Update responder with new contact
2758 let mut update = proto::UpdateContacts::default();
2759 if accept {
2760 update
2761 .contacts
2762 .push(contact_for_user(requester_id, requester_busy, &pool));
2763 }
2764 update
2765 .remove_incoming_requests
2766 .push(requester_id.to_proto());
2767 for connection_id in pool.user_connection_ids(responder_id) {
2768 session.peer.send(connection_id, update.clone())?;
2769 }
2770
2771 // Update requester with new contact
2772 let mut update = proto::UpdateContacts::default();
2773 if accept {
2774 update
2775 .contacts
2776 .push(contact_for_user(responder_id, responder_busy, &pool));
2777 }
2778 update
2779 .remove_outgoing_requests
2780 .push(responder_id.to_proto());
2781
2782 for connection_id in pool.user_connection_ids(requester_id) {
2783 session.peer.send(connection_id, update.clone())?;
2784 }
2785
2786 send_notifications(&pool, &session.peer, notifications);
2787 }
2788
2789 response.send(proto::Ack {})?;
2790 Ok(())
2791}
2792
2793/// Remove a contact.
2794async fn remove_contact(
2795 request: proto::RemoveContact,
2796 response: Response<proto::RemoveContact>,
2797 session: Session,
2798) -> Result<()> {
2799 let requester_id = session.user_id();
2800 let responder_id = UserId::from_proto(request.user_id);
2801 let db = session.db().await;
2802 let (contact_accepted, deleted_notification_id) =
2803 db.remove_contact(requester_id, responder_id).await?;
2804
2805 let pool = session.connection_pool().await;
2806 // Update outgoing contact requests of requester
2807 let mut update = proto::UpdateContacts::default();
2808 if contact_accepted {
2809 update.remove_contacts.push(responder_id.to_proto());
2810 } else {
2811 update
2812 .remove_outgoing_requests
2813 .push(responder_id.to_proto());
2814 }
2815 for connection_id in pool.user_connection_ids(requester_id) {
2816 session.peer.send(connection_id, update.clone())?;
2817 }
2818
2819 // Update incoming contact requests of responder
2820 let mut update = proto::UpdateContacts::default();
2821 if contact_accepted {
2822 update.remove_contacts.push(requester_id.to_proto());
2823 } else {
2824 update
2825 .remove_incoming_requests
2826 .push(requester_id.to_proto());
2827 }
2828 for connection_id in pool.user_connection_ids(responder_id) {
2829 session.peer.send(connection_id, update.clone())?;
2830 if let Some(notification_id) = deleted_notification_id {
2831 session.peer.send(
2832 connection_id,
2833 proto::DeleteNotification {
2834 notification_id: notification_id.to_proto(),
2835 },
2836 )?;
2837 }
2838 }
2839
2840 response.send(proto::Ack {})?;
2841 Ok(())
2842}
2843
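/// Whether the server should subscribe this client to its channels automatically
/// on connect (true for clients prior to 0.139).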
2844fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2845 version.0.minor() < 139
2846}
2847
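/// Determines the user's current plan: staff members are always treated as
/// Zed Pro, otherwise the plan is derived from the active billing subscription.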
2848async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2849 if is_staff {
2850 return Ok(proto::Plan::ZedPro);
2851 }
2852
2853 let subscription = db.get_active_billing_subscription(user_id).await?;
2854 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2855
2856 let plan = if let Some(subscription_kind) = subscription_kind {
2857 match subscription_kind {
2858 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2859 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2860 SubscriptionKind::ZedFree => proto::Plan::Free,
2861 }
2862 } else {
2863 proto::Plan::Free
2864 };
2865
2866 Ok(plan)
2867}
2868
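/// Builds the `UpdateUserPlan` message for a user, combining their plan, feature
/// flags, billing preferences, subscription period, and usage.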
2869async fn make_update_user_plan_message(
2870 user: &User,
2871 is_staff: bool,
2872 db: &Arc<Database>,
2873 llm_db: Option<Arc<LlmDatabase>>,
2874) -> Result<proto::UpdateUserPlan> {
2875 let feature_flags = db.get_user_flags(user.id).await?;
2876 let plan = current_plan(db, user.id, is_staff).await?;
2877 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2878 let billing_preferences = db.get_billing_preferences(user.id).await?;
2879
2880 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2881 let subscription = db.get_active_billing_subscription(user.id).await?;
2882
2883 let subscription_period =
2884 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2885
2886 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2887 llm_db
2888 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2889 .await?
2890 } else {
2891 None
2892 };
2893
2894 (subscription_period, usage)
2895 } else {
2896 (None, None)
2897 };
2898
2899 let bypass_account_age_check = feature_flags
2900 .iter()
2901 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2902 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2903 && !bypass_account_age_check
2904 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2905
2906 Ok(proto::UpdateUserPlan {
2907 plan: plan.into(),
2908 trial_started_at: billing_customer
2909 .as_ref()
2910 .and_then(|billing_customer| billing_customer.trial_started_at)
2911 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2912 is_usage_based_billing_enabled: if is_staff {
2913 Some(true)
2914 } else {
2915 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2916 },
2917 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2918 proto::SubscriptionPeriod {
2919 started_at: started_at.timestamp() as u64,
2920 ended_at: ended_at.timestamp() as u64,
2921 }
2922 }),
2923 account_too_young: Some(account_too_young),
2924 has_overdue_invoices: billing_customer
2925 .map(|billing_customer| billing_customer.has_overdue_invoices),
2926 usage: Some(
2927 usage
2928 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2929 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2930 ),
2931 })
2932}
2933
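/// Returns the model request limit for the given plan, raising it to 1,000 for
/// Zed Pro trial users who have the extended trial feature flag.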
2934fn model_requests_limit(
2935 plan: cloud_llm_client::Plan,
2936 feature_flags: &Vec<String>,
2937) -> cloud_llm_client::UsageLimit {
2938 match plan.model_requests_limit() {
2939 cloud_llm_client::UsageLimit::Limited(limit) => {
2940 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2941 && feature_flags
2942 .iter()
2943 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2944 {
2945 1_000
2946 } else {
2947 limit
2948 };
2949
2950 cloud_llm_client::UsageLimit::Limited(limit)
2951 }
2952 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2953 }
2954}
2955
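/// Converts a subscription usage record into its proto representation, attaching
/// the model request and edit prediction limits for the given plan.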
2956fn subscription_usage_to_proto(
2957 plan: proto::Plan,
2958 usage: crate::llm::db::subscription_usage::Model,
2959 feature_flags: &Vec<String>,
2960) -> proto::SubscriptionUsage {
2961 let plan = match plan {
2962 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2963 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2964 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2965 };
2966
2967 proto::SubscriptionUsage {
2968 model_requests_usage_amount: usage.model_requests as u32,
2969 model_requests_usage_limit: Some(proto::UsageLimit {
2970 variant: Some(match model_requests_limit(plan, feature_flags) {
2971 cloud_llm_client::UsageLimit::Limited(limit) => {
2972 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2973 limit: limit as u32,
2974 })
2975 }
2976 cloud_llm_client::UsageLimit::Unlimited => {
2977 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2978 }
2979 }),
2980 }),
2981 edit_predictions_usage_amount: usage.edit_predictions as u32,
2982 edit_predictions_usage_limit: Some(proto::UsageLimit {
2983 variant: Some(match plan.edit_predictions_limit() {
2984 cloud_llm_client::UsageLimit::Limited(limit) => {
2985 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2986 limit: limit as u32,
2987 })
2988 }
2989 cloud_llm_client::UsageLimit::Unlimited => {
2990 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2991 }
2992 }),
2993 }),
2994 }
2995}
2996
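/// Builds a zeroed usage message (still carrying the plan's limits) for users
/// with no recorded usage in the current period.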
2997fn make_default_subscription_usage(
2998 plan: proto::Plan,
2999 feature_flags: &Vec<String>,
3000) -> proto::SubscriptionUsage {
3001 let plan = match plan {
3002 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3003 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3004 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3005 };
3006
3007 proto::SubscriptionUsage {
3008 model_requests_usage_amount: 0,
3009 model_requests_usage_limit: Some(proto::UsageLimit {
3010 variant: Some(match model_requests_limit(plan, feature_flags) {
3011 cloud_llm_client::UsageLimit::Limited(limit) => {
3012 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3013 limit: limit as u32,
3014 })
3015 }
3016 cloud_llm_client::UsageLimit::Unlimited => {
3017 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3018 }
3019 }),
3020 }),
3021 edit_predictions_usage_amount: 0,
3022 edit_predictions_usage_limit: Some(proto::UsageLimit {
3023 variant: Some(match plan.edit_predictions_limit() {
3024 cloud_llm_client::UsageLimit::Limited(limit) => {
3025 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3026 limit: limit as u32,
3027 })
3028 }
3029 cloud_llm_client::UsageLimit::Unlimited => {
3030 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3031 }
3032 }),
3033 }),
3034 }
3035}
3036
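/// Sends the latest `UpdateUserPlan` message to this session's connection.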
3037async fn update_user_plan(session: &Session) -> Result<()> {
3038 let db = session.db().await;
3039
3040 let update_user_plan = make_update_user_plan_message(
3041 session.principal.user(),
3042 session.is_staff(),
3043 &db.0,
3044 session.app_state.llm_db.clone(),
3045 )
3046 .await?;
3047
3048 session
3049 .peer
3050 .send(session.connection_id, update_user_plan)
3051 .trace_err();
3052
3053 Ok(())
3054}
3055
3056async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3057 subscribe_user_to_channels(session.user_id(), &session).await?;
3058 Ok(())
3059}
3060
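/// Subscribes the user to every channel they are a member of and sends the
/// initial channel state to this connection.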
3061async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3062 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3063 let mut pool = session.connection_pool().await;
3064 for membership in &channels_for_user.channel_memberships {
3065 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3066 }
3067 session.peer.send(
3068 session.connection_id,
3069 build_update_user_channels(&channels_for_user),
3070 )?;
3071 session.peer.send(
3072 session.connection_id,
3073 build_channels_update(channels_for_user),
3074 )?;
3075 Ok(())
3076}
3077
3078/// Creates a new channel.
3079async fn create_channel(
3080 request: proto::CreateChannel,
3081 response: Response<proto::CreateChannel>,
3082 session: Session,
3083) -> Result<()> {
3084 let db = session.db().await;
3085
3086 let parent_id = request.parent_id.map(ChannelId::from_proto);
3087 let (channel, membership) = db
3088 .create_channel(&request.name, parent_id, session.user_id())
3089 .await?;
3090
3091 let root_id = channel.root_id();
3092 let channel = Channel::from_model(channel);
3093
3094 response.send(proto::CreateChannelResponse {
3095 channel: Some(channel.to_proto()),
3096 parent_id: request.parent_id,
3097 })?;
3098
3099 let mut connection_pool = session.connection_pool().await;
3100 if let Some(membership) = membership {
3101 connection_pool.subscribe_to_channel(
3102 membership.user_id,
3103 membership.channel_id,
3104 membership.role,
3105 );
3106 let update = proto::UpdateUserChannels {
3107 channel_memberships: vec![proto::ChannelMembership {
3108 channel_id: membership.channel_id.to_proto(),
3109 role: membership.role.into(),
3110 }],
3111 ..Default::default()
3112 };
3113 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3114 session.peer.send(connection_id, update.clone())?;
3115 }
3116 }
3117
3118 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3119 if !role.can_see_channel(channel.visibility) {
3120 continue;
3121 }
3122
3123 let update = proto::UpdateChannels {
3124 channels: vec![channel.to_proto()],
3125 ..Default::default()
3126 };
3127 session.peer.send(connection_id, update.clone())?;
3128 }
3129
3130 Ok(())
3131}
3132
3133/// Delete a channel
3134async fn delete_channel(
3135 request: proto::DeleteChannel,
3136 response: Response<proto::DeleteChannel>,
3137 session: Session,
3138) -> Result<()> {
3139 let db = session.db().await;
3140
3141 let channel_id = request.channel_id;
3142 let (root_channel, removed_channels) = db
3143 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3144 .await?;
3145 response.send(proto::Ack {})?;
3146
3147 // Notify members of removed channels
3148 let mut update = proto::UpdateChannels::default();
3149 update
3150 .delete_channels
3151 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3152
3153 let connection_pool = session.connection_pool().await;
3154 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3155 session.peer.send(connection_id, update.clone())?;
3156 }
3157
3158 Ok(())
3159}
3160
3161/// Invite someone to join a channel.
3162async fn invite_channel_member(
3163 request: proto::InviteChannelMember,
3164 response: Response<proto::InviteChannelMember>,
3165 session: Session,
3166) -> Result<()> {
3167 let db = session.db().await;
3168 let channel_id = ChannelId::from_proto(request.channel_id);
3169 let invitee_id = UserId::from_proto(request.user_id);
3170 let InviteMemberResult {
3171 channel,
3172 notifications,
3173 } = db
3174 .invite_channel_member(
3175 channel_id,
3176 invitee_id,
3177 session.user_id(),
3178 request.role().into(),
3179 )
3180 .await?;
3181
3182 let update = proto::UpdateChannels {
3183 channel_invitations: vec![channel.to_proto()],
3184 ..Default::default()
3185 };
3186
3187 let connection_pool = session.connection_pool().await;
3188 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3189 session.peer.send(connection_id, update.clone())?;
3190 }
3191
3192 send_notifications(&connection_pool, &session.peer, notifications);
3193
3194 response.send(proto::Ack {})?;
3195 Ok(())
3196}
3197
3198/// Remove someone from a channel.
3199async fn remove_channel_member(
3200 request: proto::RemoveChannelMember,
3201 response: Response<proto::RemoveChannelMember>,
3202 session: Session,
3203) -> Result<()> {
3204 let db = session.db().await;
3205 let channel_id = ChannelId::from_proto(request.channel_id);
3206 let member_id = UserId::from_proto(request.user_id);
3207
3208 let RemoveChannelMemberResult {
3209 membership_update,
3210 notification_id,
3211 } = db
3212 .remove_channel_member(channel_id, member_id, session.user_id())
3213 .await?;
3214
3215 let mut connection_pool = session.connection_pool().await;
3216 notify_membership_updated(
3217 &mut connection_pool,
3218 membership_update,
3219 member_id,
3220 &session.peer,
3221 );
3222 for connection_id in connection_pool.user_connection_ids(member_id) {
3223 if let Some(notification_id) = notification_id {
3224 session
3225 .peer
3226 .send(
3227 connection_id,
3228 proto::DeleteNotification {
3229 notification_id: notification_id.to_proto(),
3230 },
3231 )
3232 .trace_err();
3233 }
3234 }
3235
3236 response.send(proto::Ack {})?;
3237 Ok(())
3238}
3239
3240/// Toggle the channel between public and private.
3241/// Care is taken to maintain the invariant that public channels only descend from public channels
3242/// (though members-only channels can appear at any point in the hierarchy).
3243async fn set_channel_visibility(
3244 request: proto::SetChannelVisibility,
3245 response: Response<proto::SetChannelVisibility>,
3246 session: Session,
3247) -> Result<()> {
3248 let db = session.db().await;
3249 let channel_id = ChannelId::from_proto(request.channel_id);
3250 let visibility = request.visibility().into();
3251
3252 let channel_model = db
3253 .set_channel_visibility(channel_id, visibility, session.user_id())
3254 .await?;
3255 let root_id = channel_model.root_id();
3256 let channel = Channel::from_model(channel_model);
3257
3258 let mut connection_pool = session.connection_pool().await;
3259 for (user_id, role) in connection_pool
3260 .channel_user_ids(root_id)
3261 .collect::<Vec<_>>()
3262 .into_iter()
3263 {
3264 let update = if role.can_see_channel(channel.visibility) {
3265 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3266 proto::UpdateChannels {
3267 channels: vec![channel.to_proto()],
3268 ..Default::default()
3269 }
3270 } else {
3271 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3272 proto::UpdateChannels {
3273 delete_channels: vec![channel.id.to_proto()],
3274 ..Default::default()
3275 }
3276 };
3277
3278 for connection_id in connection_pool.user_connection_ids(user_id) {
3279 session.peer.send(connection_id, update.clone())?;
3280 }
3281 }
3282
3283 response.send(proto::Ack {})?;
3284 Ok(())
3285}
3286
3287/// Alter the role for a user in the channel.
3288async fn set_channel_member_role(
3289 request: proto::SetChannelMemberRole,
3290 response: Response<proto::SetChannelMemberRole>,
3291 session: Session,
3292) -> Result<()> {
3293 let db = session.db().await;
3294 let channel_id = ChannelId::from_proto(request.channel_id);
3295 let member_id = UserId::from_proto(request.user_id);
3296 let result = db
3297 .set_channel_member_role(
3298 channel_id,
3299 session.user_id(),
3300 member_id,
3301 request.role().into(),
3302 )
3303 .await?;
3304
3305 match result {
3306 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3307 let mut connection_pool = session.connection_pool().await;
3308 notify_membership_updated(
3309 &mut connection_pool,
3310 membership_update,
3311 member_id,
3312 &session.peer,
3313 )
3314 }
3315 db::SetMemberRoleResult::InviteUpdated(channel) => {
3316 let update = proto::UpdateChannels {
3317 channel_invitations: vec![channel.to_proto()],
3318 ..Default::default()
3319 };
3320
3321 for connection_id in session
3322 .connection_pool()
3323 .await
3324 .user_connection_ids(member_id)
3325 {
3326 session.peer.send(connection_id, update.clone())?;
3327 }
3328 }
3329 }
3330
3331 response.send(proto::Ack {})?;
3332 Ok(())
3333}
3334
3335/// Change the name of a channel
3336async fn rename_channel(
3337 request: proto::RenameChannel,
3338 response: Response<proto::RenameChannel>,
3339 session: Session,
3340) -> Result<()> {
3341 let db = session.db().await;
3342 let channel_id = ChannelId::from_proto(request.channel_id);
3343 let channel_model = db
3344 .rename_channel(channel_id, session.user_id(), &request.name)
3345 .await?;
3346 let root_id = channel_model.root_id();
3347 let channel = Channel::from_model(channel_model);
3348
3349 response.send(proto::RenameChannelResponse {
3350 channel: Some(channel.to_proto()),
3351 })?;
3352
3353 let connection_pool = session.connection_pool().await;
3354 let update = proto::UpdateChannels {
3355 channels: vec![channel.to_proto()],
3356 ..Default::default()
3357 };
3358 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3359 if role.can_see_channel(channel.visibility) {
3360 session.peer.send(connection_id, update.clone())?;
3361 }
3362 }
3363
3364 Ok(())
3365}
3366
3367/// Move a channel to a new parent.
3368async fn move_channel(
3369 request: proto::MoveChannel,
3370 response: Response<proto::MoveChannel>,
3371 session: Session,
3372) -> Result<()> {
3373 let channel_id = ChannelId::from_proto(request.channel_id);
3374 let to = ChannelId::from_proto(request.to);
3375
3376 let (root_id, channels) = session
3377 .db()
3378 .await
3379 .move_channel(channel_id, to, session.user_id())
3380 .await?;
3381
3382 let connection_pool = session.connection_pool().await;
3383 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3384 let channels = channels
3385 .iter()
3386 .filter_map(|channel| {
3387 if role.can_see_channel(channel.visibility) {
3388 Some(channel.to_proto())
3389 } else {
3390 None
3391 }
3392 })
3393 .collect::<Vec<_>>();
3394 if channels.is_empty() {
3395 continue;
3396 }
3397
3398 let update = proto::UpdateChannels {
3399 channels,
3400 ..Default::default()
3401 };
3402
3403 session.peer.send(connection_id, update.clone())?;
3404 }
3405
3406 response.send(Ack {})?;
3407 Ok(())
3408}
3409
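/// Move a channel within its ordering in the requested direction, notifying
/// channel members of the updated positions.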
3410async fn reorder_channel(
3411 request: proto::ReorderChannel,
3412 response: Response<proto::ReorderChannel>,
3413 session: Session,
3414) -> Result<()> {
3415 let channel_id = ChannelId::from_proto(request.channel_id);
3416 let direction = request.direction();
3417
3418 let updated_channels = session
3419 .db()
3420 .await
3421 .reorder_channel(channel_id, direction, session.user_id())
3422 .await?;
3423
3424 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3425 let connection_pool = session.connection_pool().await;
3426 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3427 let channels = updated_channels
3428 .iter()
3429 .filter_map(|channel| {
3430 if role.can_see_channel(channel.visibility) {
3431 Some(channel.to_proto())
3432 } else {
3433 None
3434 }
3435 })
3436 .collect::<Vec<_>>();
3437
3438 if channels.is_empty() {
3439 continue;
3440 }
3441
3442 let update = proto::UpdateChannels {
3443 channels,
3444 ..Default::default()
3445 };
3446
3447 session.peer.send(connection_id, update.clone())?;
3448 }
3449 }
3450
3451 response.send(Ack {})?;
3452 Ok(())
3453}
3454
3455/// Get the list of channel members
3456async fn get_channel_members(
3457 request: proto::GetChannelMembers,
3458 response: Response<proto::GetChannelMembers>,
3459 session: Session,
3460) -> Result<()> {
3461 let db = session.db().await;
3462 let channel_id = ChannelId::from_proto(request.channel_id);
3463 let limit = if request.limit == 0 {
3464 u16::MAX as u64
3465 } else {
3466 request.limit
3467 };
3468 let (members, users) = db
3469 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3470 .await?;
3471 response.send(proto::GetChannelMembersResponse { members, users })?;
3472 Ok(())
3473}
3474
3475/// Accept or decline a channel invitation.
3476async fn respond_to_channel_invite(
3477 request: proto::RespondToChannelInvite,
3478 response: Response<proto::RespondToChannelInvite>,
3479 session: Session,
3480) -> Result<()> {
3481 let db = session.db().await;
3482 let channel_id = ChannelId::from_proto(request.channel_id);
3483 let RespondToChannelInvite {
3484 membership_update,
3485 notifications,
3486 } = db
3487 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3488 .await?;
3489
3490 let mut connection_pool = session.connection_pool().await;
3491 if let Some(membership_update) = membership_update {
3492 notify_membership_updated(
3493 &mut connection_pool,
3494 membership_update,
3495 session.user_id(),
3496 &session.peer,
3497 );
3498 } else {
3499 let update = proto::UpdateChannels {
3500 remove_channel_invitations: vec![channel_id.to_proto()],
3501 ..Default::default()
3502 };
3503
3504 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3505 session.peer.send(connection_id, update.clone())?;
3506 }
3507 };
3508
3509 send_notifications(&connection_pool, &session.peer, notifications);
3510
3511 response.send(proto::Ack {})?;
3512
3513 Ok(())
3514}
3515
3516/// Join the channel's room.
3517async fn join_channel(
3518 request: proto::JoinChannel,
3519 response: Response<proto::JoinChannel>,
3520 session: Session,
3521) -> Result<()> {
3522 let channel_id = ChannelId::from_proto(request.channel_id);
3523 join_channel_internal(channel_id, Box::new(response), session).await
3524}
3525
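/// Allows `join_channel_internal` to reply with a `JoinRoomResponse` to either a
/// `JoinChannel` or a `JoinRoom` request.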
3526trait JoinChannelInternalResponse {
3527 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3528}
3529impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3530 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3531 Response::<proto::JoinChannel>::send(self, result)
3532 }
3533}
3534impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3535 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3536 Response::<proto::JoinRoom>::send(self, result)
3537 }
3538}
3539
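/// Shared logic for joining a channel's room: cleans up any stale room connection
/// from a previous session, joins the room, obtains LiveKit connection info when a
/// LiveKit client is configured, and notifies channel members and contacts.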
3540async fn join_channel_internal(
3541 channel_id: ChannelId,
3542 response: Box<impl JoinChannelInternalResponse>,
3543 session: Session,
3544) -> Result<()> {
3545 let joined_room = {
3546 let mut db = session.db().await;
3547        // If Zed quits without leaving the room, and the user re-opens Zed before the
3548 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3549 // room they were in.
3550 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3551 tracing::info!(
3552 stale_connection_id = %connection,
3553 "cleaning up stale connection",
3554 );
3555 drop(db);
3556 leave_room_for_session(&session, connection).await?;
3557 db = session.db().await;
3558 }
3559
3560 let (joined_room, membership_updated, role) = db
3561 .join_channel(channel_id, session.user_id(), session.connection_id)
3562 .await?;
3563
3564 let live_kit_connection_info =
3565 session
3566 .app_state
3567 .livekit_client
3568 .as_ref()
3569 .and_then(|live_kit| {
3570 let (can_publish, token) = if role == ChannelRole::Guest {
3571 (
3572 false,
3573 live_kit
3574 .guest_token(
3575 &joined_room.room.livekit_room,
3576 &session.user_id().to_string(),
3577 )
3578 .trace_err()?,
3579 )
3580 } else {
3581 (
3582 true,
3583 live_kit
3584 .room_token(
3585 &joined_room.room.livekit_room,
3586 &session.user_id().to_string(),
3587 )
3588 .trace_err()?,
3589 )
3590 };
3591
3592 Some(LiveKitConnectionInfo {
3593 server_url: live_kit.url().into(),
3594 token,
3595 can_publish,
3596 })
3597 });
3598
3599 response.send(proto::JoinRoomResponse {
3600 room: Some(joined_room.room.clone()),
3601 channel_id: joined_room
3602 .channel
3603 .as_ref()
3604 .map(|channel| channel.id.to_proto()),
3605 live_kit_connection_info,
3606 })?;
3607
3608 let mut connection_pool = session.connection_pool().await;
3609 if let Some(membership_updated) = membership_updated {
3610 notify_membership_updated(
3611 &mut connection_pool,
3612 membership_updated,
3613 session.user_id(),
3614 &session.peer,
3615 );
3616 }
3617
3618 room_updated(&joined_room.room, &session.peer);
3619
3620 joined_room
3621 };
3622
3623 channel_updated(
3624 &joined_room.channel.context("channel not returned")?,
3625 &joined_room.room,
3626 &session.peer,
3627 &*session.connection_pool().await,
3628 );
3629
3630 update_user_contacts(session.user_id(), &session).await?;
3631 Ok(())
3632}
3633
3634/// Start editing the channel notes
3635async fn join_channel_buffer(
3636 request: proto::JoinChannelBuffer,
3637 response: Response<proto::JoinChannelBuffer>,
3638 session: Session,
3639) -> Result<()> {
3640 let db = session.db().await;
3641 let channel_id = ChannelId::from_proto(request.channel_id);
3642
3643 let open_response = db
3644 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3645 .await?;
3646
3647 let collaborators = open_response.collaborators.clone();
3648 response.send(open_response)?;
3649
3650 let update = UpdateChannelBufferCollaborators {
3651 channel_id: channel_id.to_proto(),
3652 collaborators: collaborators.clone(),
3653 };
3654 channel_buffer_updated(
3655 session.connection_id,
3656 collaborators
3657 .iter()
3658 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3659 &update,
3660 &session.peer,
3661 );
3662
3663 Ok(())
3664}
3665
3666/// Edit the channel notes
3667async fn update_channel_buffer(
3668 request: proto::UpdateChannelBuffer,
3669 session: Session,
3670) -> Result<()> {
3671 let db = session.db().await;
3672 let channel_id = ChannelId::from_proto(request.channel_id);
3673
3674 let (collaborators, epoch, version) = db
3675 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3676 .await?;
3677
3678 channel_buffer_updated(
3679 session.connection_id,
3680 collaborators.clone(),
3681 &proto::UpdateChannelBuffer {
3682 channel_id: channel_id.to_proto(),
3683 operations: request.operations,
3684 },
3685 &session.peer,
3686 );
3687
3688 let pool = &*session.connection_pool().await;
3689
3690 let non_collaborators =
3691 pool.channel_connection_ids(channel_id)
3692 .filter_map(|(connection_id, _)| {
3693 if collaborators.contains(&connection_id) {
3694 None
3695 } else {
3696 Some(connection_id)
3697 }
3698 });
3699
3700 broadcast(None, non_collaborators, |peer_id| {
3701 session.peer.send(
3702 peer_id,
3703 proto::UpdateChannels {
3704 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3705 channel_id: channel_id.to_proto(),
3706 epoch: epoch as u64,
3707 version: version.clone(),
3708 }],
3709 ..Default::default()
3710 },
3711 )
3712 });
3713
3714 Ok(())
3715}
3716
3717/// Rejoin the channel notes after a connection blip
3718async fn rejoin_channel_buffers(
3719 request: proto::RejoinChannelBuffers,
3720 response: Response<proto::RejoinChannelBuffers>,
3721 session: Session,
3722) -> Result<()> {
3723 let db = session.db().await;
3724 let buffers = db
3725 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3726 .await?;
3727
3728 for rejoined_buffer in &buffers {
3729 let collaborators_to_notify = rejoined_buffer
3730 .buffer
3731 .collaborators
3732 .iter()
3733 .filter_map(|c| Some(c.peer_id?.into()));
3734 channel_buffer_updated(
3735 session.connection_id,
3736 collaborators_to_notify,
3737 &proto::UpdateChannelBufferCollaborators {
3738 channel_id: rejoined_buffer.buffer.channel_id,
3739 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3740 },
3741 &session.peer,
3742 );
3743 }
3744
3745 response.send(proto::RejoinChannelBuffersResponse {
3746 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3747 })?;
3748
3749 Ok(())
3750}
3751
3752/// Stop editing the channel notes
3753async fn leave_channel_buffer(
3754 request: proto::LeaveChannelBuffer,
3755 response: Response<proto::LeaveChannelBuffer>,
3756 session: Session,
3757) -> Result<()> {
3758 let db = session.db().await;
3759 let channel_id = ChannelId::from_proto(request.channel_id);
3760
3761 let left_buffer = db
3762 .leave_channel_buffer(channel_id, session.connection_id)
3763 .await?;
3764
3765 response.send(Ack {})?;
3766
3767 channel_buffer_updated(
3768 session.connection_id,
3769 left_buffer.connections,
3770 &proto::UpdateChannelBufferCollaborators {
3771 channel_id: channel_id.to_proto(),
3772 collaborators: left_buffer.collaborators,
3773 },
3774 &session.peer,
3775 );
3776
3777 Ok(())
3778}
3779
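/// Broadcast a channel buffer message to the given collaborator connections, excluding the sender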
3780fn channel_buffer_updated<T: EnvelopedMessage>(
3781 sender_id: ConnectionId,
3782 collaborators: impl IntoIterator<Item = ConnectionId>,
3783 message: &T,
3784 peer: &Peer,
3785) {
3786 broadcast(Some(sender_id), collaborators, |peer_id| {
3787 peer.send(peer_id, message.clone())
3788 });
3789}
3790
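/// Send each notification in the batch to every connection of its recipient, logging any send failures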
3791fn send_notifications(
3792 connection_pool: &ConnectionPool,
3793 peer: &Peer,
3794 notifications: db::NotificationBatch,
3795) {
3796 for (user_id, notification) in notifications {
3797 for connection_id in connection_pool.user_connection_ids(user_id) {
3798 if let Err(error) = peer.send(
3799 connection_id,
3800 proto::AddNotification {
3801 notification: Some(notification.clone()),
3802 },
3803 ) {
3804 tracing::error!(
3805 "failed to send notification to {:?} {}",
3806 connection_id,
3807 error
3808 );
3809 }
3810 }
3811 }
3812}
3813
3814/// Send a message to the channel
3815async fn send_channel_message(
3816 request: proto::SendChannelMessage,
3817 response: Response<proto::SendChannelMessage>,
3818 session: Session,
3819) -> Result<()> {
3820 // Validate the message body.
3821 let body = request.body.trim().to_string();
3822 if body.len() > MAX_MESSAGE_LEN {
3823 return Err(anyhow!("message is too long"))?;
3824 }
3825 if body.is_empty() {
3826 return Err(anyhow!("message can't be blank"))?;
3827 }
3828
3829 // TODO: adjust mentions if body is trimmed
3830
3831 let timestamp = OffsetDateTime::now_utc();
3832 let nonce = request.nonce.context("nonce can't be blank")?;
3833
3834 let channel_id = ChannelId::from_proto(request.channel_id);
3835 let CreatedChannelMessage {
3836 message_id,
3837 participant_connection_ids,
3838 notifications,
3839 } = session
3840 .db()
3841 .await
3842 .create_channel_message(
3843 channel_id,
3844 session.user_id(),
3845 &body,
3846 &request.mentions,
3847 timestamp,
3848 nonce.clone().into(),
3849 request.reply_to_message_id.map(MessageId::from_proto),
3850 )
3851 .await?;
3852
3853 let message = proto::ChannelMessage {
3854 sender_id: session.user_id().to_proto(),
3855 id: message_id.to_proto(),
3856 body,
3857 mentions: request.mentions,
3858 timestamp: timestamp.unix_timestamp() as u64,
3859 nonce: Some(nonce),
3860 reply_to_message_id: request.reply_to_message_id,
3861 edited_at: None,
3862 };
3863 broadcast(
3864 Some(session.connection_id),
3865 participant_connection_ids.clone(),
3866 |connection| {
3867 session.peer.send(
3868 connection,
3869 proto::ChannelMessageSent {
3870 channel_id: channel_id.to_proto(),
3871 message: Some(message.clone()),
3872 },
3873 )
3874 },
3875 );
3876 response.send(proto::SendChannelMessageResponse {
3877 message: Some(message),
3878 })?;
3879
3880 let pool = &*session.connection_pool().await;
3881 let non_participants =
3882 pool.channel_connection_ids(channel_id)
3883 .filter_map(|(connection_id, _)| {
3884 if participant_connection_ids.contains(&connection_id) {
3885 None
3886 } else {
3887 Some(connection_id)
3888 }
3889 });
3890 broadcast(None, non_participants, |peer_id| {
3891 session.peer.send(
3892 peer_id,
3893 proto::UpdateChannels {
3894 latest_channel_message_ids: vec![proto::ChannelMessageId {
3895 channel_id: channel_id.to_proto(),
3896 message_id: message_id.to_proto(),
3897 }],
3898 ..Default::default()
3899 },
3900 )
3901 });
3902 send_notifications(pool, &session.peer, notifications);
3903
3904 Ok(())
3905}
3906
3907/// Delete a channel message
3908async fn remove_channel_message(
3909 request: proto::RemoveChannelMessage,
3910 response: Response<proto::RemoveChannelMessage>,
3911 session: Session,
3912) -> Result<()> {
3913 let channel_id = ChannelId::from_proto(request.channel_id);
3914 let message_id = MessageId::from_proto(request.message_id);
3915 let (connection_ids, existing_notification_ids) = session
3916 .db()
3917 .await
3918 .remove_channel_message(channel_id, message_id, session.user_id())
3919 .await?;
3920
3921 broadcast(
3922 Some(session.connection_id),
3923 connection_ids,
3924 move |connection| {
3925 session.peer.send(connection, request.clone())?;
3926
3927 for notification_id in &existing_notification_ids {
3928 session.peer.send(
3929 connection,
3930 proto::DeleteNotification {
3931 notification_id: (*notification_id).to_proto(),
3932 },
3933 )?;
3934 }
3935
3936 Ok(())
3937 },
3938 );
3939 response.send(proto::Ack {})?;
3940 Ok(())
3941}
3942
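/// Edit a channel message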
3943async fn update_channel_message(
3944 request: proto::UpdateChannelMessage,
3945 response: Response<proto::UpdateChannelMessage>,
3946 session: Session,
3947) -> Result<()> {
3948 let channel_id = ChannelId::from_proto(request.channel_id);
3949 let message_id = MessageId::from_proto(request.message_id);
3950 let updated_at = OffsetDateTime::now_utc();
3951 let UpdatedChannelMessage {
3952 message_id,
3953 participant_connection_ids,
3954 notifications,
3955 reply_to_message_id,
3956 timestamp,
3957 deleted_mention_notification_ids,
3958 updated_mention_notifications,
3959 } = session
3960 .db()
3961 .await
3962 .update_channel_message(
3963 channel_id,
3964 message_id,
3965 session.user_id(),
3966 request.body.as_str(),
3967 &request.mentions,
3968 updated_at,
3969 )
3970 .await?;
3971
3972 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3973
3974 let message = proto::ChannelMessage {
3975 sender_id: session.user_id().to_proto(),
3976 id: message_id.to_proto(),
3977 body: request.body.clone(),
3978 mentions: request.mentions.clone(),
3979 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3980 nonce: Some(nonce),
3981 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3982 edited_at: Some(updated_at.unix_timestamp() as u64),
3983 };
3984
3985 response.send(proto::Ack {})?;
3986
3987 let pool = &*session.connection_pool().await;
3988 broadcast(
3989 Some(session.connection_id),
3990 participant_connection_ids,
3991 |connection| {
3992 session.peer.send(
3993 connection,
3994 proto::ChannelMessageUpdate {
3995 channel_id: channel_id.to_proto(),
3996 message: Some(message.clone()),
3997 },
3998 )?;
3999
4000 for notification_id in &deleted_mention_notification_ids {
4001 session.peer.send(
4002 connection,
4003 proto::DeleteNotification {
4004 notification_id: (*notification_id).to_proto(),
4005 },
4006 )?;
4007 }
4008
4009 for notification in &updated_mention_notifications {
4010 session.peer.send(
4011 connection,
4012 proto::UpdateNotification {
4013 notification: Some(notification.clone()),
4014 },
4015 )?;
4016 }
4017
4018 Ok(())
4019 },
4020 );
4021
4022 send_notifications(pool, &session.peer, notifications);
4023
4024 Ok(())
4025}
4026
4027/// Mark a channel message as read
4028async fn acknowledge_channel_message(
4029 request: proto::AckChannelMessage,
4030 session: Session,
4031) -> Result<()> {
4032 let channel_id = ChannelId::from_proto(request.channel_id);
4033 let message_id = MessageId::from_proto(request.message_id);
4034 let notifications = session
4035 .db()
4036 .await
4037 .observe_channel_message(channel_id, session.user_id(), message_id)
4038 .await?;
4039 send_notifications(
4040 &*session.connection_pool().await,
4041 &session.peer,
4042 notifications,
4043 );
4044 Ok(())
4045}
4046
4047/// Mark a buffer version as synced
4048async fn acknowledge_buffer_version(
4049 request: proto::AckBufferOperation,
4050 session: Session,
4051) -> Result<()> {
4052 let buffer_id = BufferId::from_proto(request.buffer_id);
4053 session
4054 .db()
4055 .await
4056 .observe_buffer_version(
4057 buffer_id,
4058 session.user_id(),
4059 request.epoch as i32,
4060 &request.version,
4061 )
4062 .await?;
4063 Ok(())
4064}
4065
4066/// Get a Supermaven API key for the user
4067async fn get_supermaven_api_key(
4068 _request: proto::GetSupermavenApiKey,
4069 response: Response<proto::GetSupermavenApiKey>,
4070 session: Session,
4071) -> Result<()> {
4072 let user_id: String = session.user_id().to_string();
4073 if !session.is_staff() {
4074 return Err(anyhow!("supermaven not enabled for this account"))?;
4075 }
4076
4077 let email = session.email().context("user must have an email")?;
4078
4079 let supermaven_admin_api = session
4080 .supermaven_client
4081 .as_ref()
4082 .context("supermaven not configured")?;
4083
4084 let result = supermaven_admin_api
4085 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4086 .await?;
4087
4088 response.send(proto::GetSupermavenApiKeyResponse {
4089 api_key: result.api_key,
4090 })?;
4091
4092 Ok(())
4093}
4094
4095/// Start receiving chat updates for a channel
4096async fn join_channel_chat(
4097 request: proto::JoinChannelChat,
4098 response: Response<proto::JoinChannelChat>,
4099 session: Session,
4100) -> Result<()> {
4101 let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103 let db = session.db().await;
4104 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4105 .await?;
4106 let messages = db
4107 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4108 .await?;
4109 response.send(proto::JoinChannelChatResponse {
4110 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4111 messages,
4112 })?;
4113 Ok(())
4114}
4115
4116/// Stop receiving chat updates for a channel
4117async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4118 let channel_id = ChannelId::from_proto(request.channel_id);
4119 session
4120 .db()
4121 .await
4122 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4123 .await?;
4124 Ok(())
4125}
4126
4127/// Retrieve the chat history for a channel
4128async fn get_channel_messages(
4129 request: proto::GetChannelMessages,
4130 response: Response<proto::GetChannelMessages>,
4131 session: Session,
4132) -> Result<()> {
4133 let channel_id = ChannelId::from_proto(request.channel_id);
4134 let messages = session
4135 .db()
4136 .await
4137 .get_channel_messages(
4138 channel_id,
4139 session.user_id(),
4140 MESSAGE_COUNT_PER_PAGE,
4141 Some(MessageId::from_proto(request.before_message_id)),
4142 )
4143 .await?;
4144 response.send(proto::GetChannelMessagesResponse {
4145 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4146 messages,
4147 })?;
4148 Ok(())
4149}
4150
4151/// Retrieve specific chat messages
4152async fn get_channel_messages_by_id(
4153 request: proto::GetChannelMessagesById,
4154 response: Response<proto::GetChannelMessagesById>,
4155 session: Session,
4156) -> Result<()> {
4157 let message_ids = request
4158 .message_ids
4159 .iter()
4160 .map(|id| MessageId::from_proto(*id))
4161 .collect::<Vec<_>>();
4162 let messages = session
4163 .db()
4164 .await
4165 .get_channel_messages_by_id(session.user_id(), &message_ids)
4166 .await?;
4167 response.send(proto::GetChannelMessagesResponse {
4168 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4169 messages,
4170 })?;
4171 Ok(())
4172}
4173
/// Retrieve the current user's notifications
4175async fn get_notifications(
4176 request: proto::GetNotifications,
4177 response: Response<proto::GetNotifications>,
4178 session: Session,
4179) -> Result<()> {
4180 let notifications = session
4181 .db()
4182 .await
4183 .get_notifications(
4184 session.user_id(),
4185 NOTIFICATION_COUNT_PER_PAGE,
4186 request.before_id.map(db::NotificationId::from_proto),
4187 )
4188 .await?;
4189 response.send(proto::GetNotificationsResponse {
4190 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4191 notifications,
4192 })?;
4193 Ok(())
4194}
4195
/// Mark a notification as read
4197async fn mark_notification_as_read(
4198 request: proto::MarkNotificationRead,
4199 response: Response<proto::MarkNotificationRead>,
4200 session: Session,
4201) -> Result<()> {
4202 let database = &session.db().await;
4203 let notifications = database
4204 .mark_notification_as_read_by_id(
4205 session.user_id(),
4206 NotificationId::from_proto(request.notification_id),
4207 )
4208 .await?;
4209 send_notifications(
4210 &*session.connection_pool().await,
4211 &session.peer,
4212 notifications,
4213 );
4214 response.send(proto::Ack {})?;
4215 Ok(())
4216}
4217
/// Get the current user's private information
4219async fn get_private_user_info(
4220 _request: proto::GetPrivateUserInfo,
4221 response: Response<proto::GetPrivateUserInfo>,
4222 session: Session,
4223) -> Result<()> {
4224 let db = session.db().await;
4225
4226 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4227 let user = db
4228 .get_user_by_id(session.user_id())
4229 .await?
4230 .context("user not found")?;
4231 let flags = db.get_user_flags(session.user_id()).await?;
4232
4233 response.send(proto::GetPrivateUserInfoResponse {
4234 metrics_id,
4235 staff: user.admin,
4236 flags,
4237 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4238 })?;
4239 Ok(())
4240}
4241
/// Accept the terms of service (ToS) on behalf of the current user
4243async fn accept_terms_of_service(
4244 _request: proto::AcceptTermsOfService,
4245 response: Response<proto::AcceptTermsOfService>,
4246 session: Session,
4247) -> Result<()> {
4248 let db = session.db().await;
4249
4250 let accepted_tos_at = Utc::now();
4251 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4252 .await?;
4253
4254 response.send(proto::AcceptTermsOfServiceResponse {
4255 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4256 })?;
4257
4258 // When the user accepts the terms of service, we want to refresh their LLM
4259 // token to grant access.
4260 session
4261 .peer
4262 .send(session.connection_id, proto::RefreshLlmToken {})?;
4263
4264 Ok(())
4265}
4266
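/// Issue an LLM API token for the current user, creating a Stripe customer and Zed Free subscription if they don't already have one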
4267async fn get_llm_api_token(
4268 _request: proto::GetLlmToken,
4269 response: Response<proto::GetLlmToken>,
4270 session: Session,
4271) -> Result<()> {
4272 let db = session.db().await;
4273
4274 let flags = db.get_user_flags(session.user_id()).await?;
4275
4276 let user_id = session.user_id();
4277 let user = db
4278 .get_user_by_id(user_id)
4279 .await?
4280 .with_context(|| format!("user {user_id} not found"))?;
4281
4282 if user.accepted_tos_at.is_none() {
4283 Err(anyhow!("terms of service not accepted"))?
4284 }
4285
4286 let stripe_client = session
4287 .app_state
4288 .stripe_client
4289 .as_ref()
4290 .context("failed to retrieve Stripe client")?;
4291
4292 let stripe_billing = session
4293 .app_state
4294 .stripe_billing
4295 .as_ref()
4296 .context("failed to retrieve Stripe billing object")?;
4297
4298 let billing_customer = if let Some(billing_customer) =
4299 db.get_billing_customer_by_user_id(user.id).await?
4300 {
4301 billing_customer
4302 } else {
4303 let customer_id = stripe_billing
4304 .find_or_create_customer_by_email(user.email_address.as_deref())
4305 .await?;
4306
4307 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4308 .await?
4309 .context("billing customer not found")?
4310 };
4311
4312 let billing_subscription =
4313 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4314 billing_subscription
4315 } else {
4316 let stripe_customer_id =
4317 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4318
4319 let stripe_subscription = stripe_billing
4320 .subscribe_to_zed_free(stripe_customer_id)
4321 .await?;
4322
4323 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4324 billing_customer_id: billing_customer.id,
4325 kind: Some(SubscriptionKind::ZedFree),
4326 stripe_subscription_id: stripe_subscription.id.to_string(),
4327 stripe_subscription_status: stripe_subscription.status.into(),
4328 stripe_cancellation_reason: None,
4329 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4330 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4331 })
4332 .await?
4333 };
4334
4335 let billing_preferences = db.get_billing_preferences(user.id).await?;
4336
4337 let token = LlmTokenClaims::create(
4338 &user,
4339 session.is_staff(),
4340 billing_customer,
4341 billing_preferences,
4342 &flags,
4343 billing_subscription,
4344 session.system_id.clone(),
4345 &session.app_state.config,
4346 )?;
4347 response.send(proto::GetLlmTokenResponse { token })?;
4348 Ok(())
4349}
4350
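/// Convert a `tungstenite` WebSocket message into its `axum` counterpart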
4351fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4352 let message = match message {
4353 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4354 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4355 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4356 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4357 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4358 code: frame.code.into(),
4359 reason: frame.reason.as_str().to_owned().into(),
4360 })),
4361 // We should never receive a frame while reading the message, according
4362 // to the `tungstenite` maintainers:
4363 //
4364 // > It cannot occur when you read messages from the WebSocket, but it
4365 // > can be used when you want to send the raw frames (e.g. you want to
4366 // > send the frames to the WebSocket without composing the full message first).
4367 // >
4368 // > — https://github.com/snapview/tungstenite-rs/issues/268
4369 TungsteniteMessage::Frame(_) => {
4370 bail!("received an unexpected frame while reading the message")
4371 }
4372 };
4373
4374 Ok(message)
4375}
4376
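/// Convert an `axum` WebSocket message into its `tungstenite` counterpart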
4377fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4378 match message {
4379 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4380 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4381 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4382 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4383 AxumMessage::Close(frame) => {
4384 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4385 code: frame.code.into(),
4386 reason: frame.reason.as_ref().into(),
4387 }))
4388 }
4389 }
4390}
4391
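/// Update the connection pool's channel subscriptions for a user and push the resulting channel updates to all of their connections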
4392fn notify_membership_updated(
4393 connection_pool: &mut ConnectionPool,
4394 result: MembershipUpdated,
4395 user_id: UserId,
4396 peer: &Peer,
4397) {
4398 for membership in &result.new_channels.channel_memberships {
4399 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4400 }
4401 for channel_id in &result.removed_channels {
4402 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4403 }
4404
4405 let user_channels_update = proto::UpdateUserChannels {
4406 channel_memberships: result
4407 .new_channels
4408 .channel_memberships
4409 .iter()
4410 .map(|cm| proto::ChannelMembership {
4411 channel_id: cm.channel_id.to_proto(),
4412 role: cm.role.into(),
4413 })
4414 .collect(),
4415 ..Default::default()
4416 };
4417
4418 let mut update = build_channels_update(result.new_channels);
4419 update.delete_channels = result
4420 .removed_channels
4421 .into_iter()
4422 .map(|id| id.to_proto())
4423 .collect();
4424 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4425
4426 for connection_id in connection_pool.user_connection_ids(user_id) {
4427 peer.send(connection_id, user_channels_update.clone())
4428 .trace_err();
4429 peer.send(connection_id, update.clone()).trace_err();
4430 }
4431}
4432
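/// Build an `UpdateUserChannels` message from a user's channel memberships and observed buffer versions and message ids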
4433fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4434 proto::UpdateUserChannels {
4435 channel_memberships: channels
4436 .channel_memberships
4437 .iter()
4438 .map(|m| proto::ChannelMembership {
4439 channel_id: m.channel_id.to_proto(),
4440 role: m.role.into(),
4441 })
4442 .collect(),
4443 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4444 observed_channel_message_id: channels.observed_channel_messages.clone(),
4445 }
4446}
4447
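/// Build an `UpdateChannels` message containing a user's channels, invitations, participants, and latest buffer/message state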
4448fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4449 let mut update = proto::UpdateChannels::default();
4450
4451 for channel in channels.channels {
4452 update.channels.push(channel.to_proto());
4453 }
4454
4455 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4456 update.latest_channel_message_ids = channels.latest_channel_messages;
4457
4458 for (channel_id, participants) in channels.channel_participants {
4459 update
4460 .channel_participants
4461 .push(proto::ChannelParticipants {
4462 channel_id: channel_id.to_proto(),
4463 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4464 });
4465 }
4466
4467 for channel in channels.invited_channels {
4468 update.channel_invitations.push(channel.to_proto());
4469 }
4470
4471 update
4472}
4473
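/// Build the initial `UpdateContacts` message from a user's contacts, split into accepted contacts, incoming requests, and outgoing requests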
4474fn build_initial_contacts_update(
4475 contacts: Vec<db::Contact>,
4476 pool: &ConnectionPool,
4477) -> proto::UpdateContacts {
4478 let mut update = proto::UpdateContacts::default();
4479
4480 for contact in contacts {
4481 match contact {
4482 db::Contact::Accepted { user_id, busy } => {
4483 update.contacts.push(contact_for_user(user_id, busy, pool));
4484 }
4485 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4486 db::Contact::Incoming { user_id } => {
4487 update
4488 .incoming_requests
4489 .push(proto::IncomingContactRequest {
4490 requester_id: user_id.to_proto(),
4491 })
4492 }
4493 }
4494 }
4495
4496 update
4497}
4498
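/// Build a `Contact` for a user, including whether they are currently online and busy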
4499fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4500 proto::Contact {
4501 user_id: user_id.to_proto(),
4502 online: pool.is_user_online(user_id),
4503 busy,
4504 }
4505}
4506
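/// Send the current room state to every participant in the room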
4507fn room_updated(room: &proto::Room, peer: &Peer) {
4508 broadcast(
4509 None,
4510 room.participants
4511 .iter()
4512 .filter_map(|participant| Some(participant.peer_id?.into())),
4513 |peer_id| {
4514 peer.send(
4515 peer_id,
4516 proto::RoomUpdated {
4517 room: Some(room.clone()),
4518 },
4519 )
4520 },
4521 );
4522}
4523
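/// Notify connections subscribed to the channel's root, and allowed to see the channel, of the room's current participants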
4524fn channel_updated(
4525 channel: &db::channel::Model,
4526 room: &proto::Room,
4527 peer: &Peer,
4528 pool: &ConnectionPool,
4529) {
4530 let participants = room
4531 .participants
4532 .iter()
4533 .map(|p| p.user_id)
4534 .collect::<Vec<_>>();
4535
4536 broadcast(
4537 None,
4538 pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
4542 }),
4543 |peer_id| {
4544 peer.send(
4545 peer_id,
4546 proto::UpdateChannels {
4547 channel_participants: vec![proto::ChannelParticipants {
4548 channel_id: channel.id.to_proto(),
4549 participant_user_ids: participants.clone(),
4550 }],
4551 ..Default::default()
4552 },
4553 )
4554 },
4555 );
4556}
4557
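/// Push a user's updated online/busy status to all of their accepted contacts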
4558async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4559 let db = session.db().await;
4560
4561 let contacts = db.get_contacts(user_id).await?;
4562 let busy = db.is_user_busy(user_id).await?;
4563
4564 let pool = session.connection_pool().await;
4565 let updated_contact = contact_for_user(user_id, busy, &pool);
4566 for contact in contacts {
4567 if let db::Contact::Accepted {
4568 user_id: contact_user_id,
4569 ..
4570 } = contact
4571 {
4572 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4573 session
4574 .peer
4575 .send(
4576 contact_conn_id,
4577 proto::UpdateContacts {
4578 contacts: vec![updated_contact.clone()],
4579 remove_contacts: Default::default(),
4580 incoming_requests: Default::default(),
4581 remove_incoming_requests: Default::default(),
4582 outgoing_requests: Default::default(),
4583 remove_outgoing_requests: Default::default(),
4584 },
4585 )
4586 .trace_err();
4587 }
4588 }
4589 }
4590 Ok(())
4591}
4592
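/// Remove a connection from its room, updating contacts, notifying canceled callees, and cleaning up the LiveKit room if it was deleted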
4593async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4594 let mut contacts_to_update = HashSet::default();
4595
4596 let room_id;
4597 let canceled_calls_to_user_ids;
4598 let livekit_room;
4599 let delete_livekit_room;
4600 let room;
4601 let channel;
4602
4603 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4604 contacts_to_update.insert(session.user_id());
4605
4606 for project in left_room.left_projects.values() {
4607 project_left(project, session);
4608 }
4609
4610 room_id = RoomId::from_proto(left_room.room.id);
4611 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4612 livekit_room = mem::take(&mut left_room.room.livekit_room);
4613 delete_livekit_room = left_room.deleted;
4614 room = mem::take(&mut left_room.room);
4615 channel = mem::take(&mut left_room.channel);
4616
4617 room_updated(&room, &session.peer);
4618 } else {
4619 return Ok(());
4620 }
4621
4622 if let Some(channel) = channel {
4623 channel_updated(
4624 &channel,
4625 &room,
4626 &session.peer,
4627 &*session.connection_pool().await,
4628 );
4629 }
4630
4631 {
4632 let pool = session.connection_pool().await;
4633 for canceled_user_id in canceled_calls_to_user_ids {
4634 for connection_id in pool.user_connection_ids(canceled_user_id) {
4635 session
4636 .peer
4637 .send(
4638 connection_id,
4639 proto::CallCanceled {
4640 room_id: room_id.to_proto(),
4641 },
4642 )
4643 .trace_err();
4644 }
4645 contacts_to_update.insert(canceled_user_id);
4646 }
4647 }
4648
4649 for contact_user_id in contacts_to_update {
4650 update_user_contacts(contact_user_id, session).await?;
4651 }
4652
4653 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4654 live_kit
4655 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4656 .await
4657 .trace_err();
4658
4659 if delete_livekit_room {
4660 live_kit.delete_room(livekit_room).await.trace_err();
4661 }
4662 }
4663
4664 Ok(())
4665}
4666
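/// Remove a connection from all channel buffers it had joined and notify the remaining collaborators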
4667async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4668 let left_channel_buffers = session
4669 .db()
4670 .await
4671 .leave_channel_buffers(session.connection_id)
4672 .await?;
4673
4674 for left_buffer in left_channel_buffers {
4675 channel_buffer_updated(
4676 session.connection_id,
4677 left_buffer.connections,
4678 &proto::UpdateChannelBufferCollaborators {
4679 channel_id: left_buffer.channel_id.to_proto(),
4680 collaborators: left_buffer.collaborators,
4681 },
4682 &session.peer,
4683 );
4684 }
4685
4686 Ok(())
4687}
4688
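/// Notify a left project's remaining connections, either unsharing the project or removing this session as a collaborator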
4689fn project_left(project: &db::LeftProject, session: &Session) {
4690 for connection_id in &project.connection_ids {
4691 if project.should_unshare {
4692 session
4693 .peer
4694 .send(
4695 *connection_id,
4696 proto::UnshareProject {
4697 project_id: project.id.to_proto(),
4698 },
4699 )
4700 .trace_err();
4701 } else {
4702 session
4703 .peer
4704 .send(
4705 *connection_id,
4706 proto::RemoveProjectCollaborator {
4707 project_id: project.id.to_proto(),
4708 peer_id: Some(session.connection_id.into()),
4709 },
4710 )
4711 .trace_err();
4712 }
4713 }
4714}
4715
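/// Extension trait that logs an error and discards it, converting a `Result` into an `Option`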
4716pub trait ResultExt {
4717 type Ok;
4718
4719 fn trace_err(self) -> Option<Self::Ok>;
4720}
4721
4722impl<T, E> ResultExt for Result<T, E>
4723where
4724 E: std::fmt::Debug,
4725{
4726 type Ok = T;
4727
4728 #[track_caller]
4729 fn trace_err(self) -> Option<T> {
4730 match self {
4731 Ok(value) => Some(value),
4732 Err(error) => {
4733 tracing::error!("{:?}", error);
4734 None
4735 }
4736 }
4737 }
4738}