1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use reqwest_client::ReqwestClient;
45use rpc::proto::{MultiLspQuery, split_repository_update};
46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
86
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
89
90const MESSAGE_COUNT_PER_PAGE: usize = 100;
91const MAX_MESSAGE_LEN: usize = 1024;
92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
94
95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
96
97type MessageHandler =
98 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
99
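/// RAII guard counting one active connection against the `MAX_CONCURRENT_CONNECTIONS`
/// limit; dropping the guard releases the slot.
///
/// A minimal usage sketch (hypothetical caller; the real call site is
/// `handle_websocket_request` below):
///
/// ```ignore
/// let Ok(guard) = ConnectionGuard::try_acquire() else {
///     // Reject the connection: the server is at capacity.
///     return;
/// };
/// // ... handle the connection; the slot is released when `guard` is dropped.
/// ```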
100pub struct ConnectionGuard;
101
102impl ConnectionGuard {
103 pub fn try_acquire() -> Result<Self, ()> {
104 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
105 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
106 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
107 tracing::error!(
108 "too many concurrent connections: {}",
109 current_connections + 1
110 );
111 return Err(());
112 }
113 Ok(ConnectionGuard)
114 }
115}
116
117impl Drop for ConnectionGuard {
118 fn drop(&mut self) {
119 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
120 }
121}
122
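/// Handle for replying to a single RPC request. It records whether the handler
/// actually responded, so `add_request_handler` can flag handlers that finish
/// without sending a response.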
123struct Response<R> {
124 peer: Arc<Peer>,
125 receipt: Receipt<R>,
126 responded: Arc<AtomicBool>,
127}
128
129impl<R: RequestMessage> Response<R> {
130 fn send(self, payload: R::Response) -> Result<()> {
131 self.responded.store(true, SeqCst);
132 self.peer.respond(self.receipt, payload)?;
133 Ok(())
134 }
135}
136
137#[derive(Clone, Debug)]
138pub enum Principal {
139 User(User),
140 Impersonated { user: User, admin: User },
141}
142
143impl Principal {
144 fn user(&self) -> &User {
145 match self {
146 Principal::User(user) => user,
147 Principal::Impersonated { user, .. } => user,
148 }
149 }
150
151 fn update_span(&self, span: &tracing::Span) {
152 match &self {
153 Principal::User(user) => {
154 span.record("user_id", user.id.0);
155 span.record("login", &user.github_login);
156 }
157 Principal::Impersonated { user, admin } => {
158 span.record("user_id", user.id.0);
159 span.record("login", &user.github_login);
160 span.record("impersonator", &admin.github_login);
161 }
162 }
163 }
164}
165
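/// Per-connection state handed to every message handler: the authenticated
/// principal plus handles to the database, peer, connection pool, and app state.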
166#[derive(Clone)]
167struct Session {
168 principal: Principal,
169 connection_id: ConnectionId,
170 db: Arc<tokio::sync::Mutex<DbHandle>>,
171 peer: Arc<Peer>,
172 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
173 app_state: Arc<AppState>,
174 supermaven_client: Option<Arc<SupermavenAdminApi>>,
175 /// The GeoIP country code for the user.
176 #[allow(unused)]
177 geoip_country_code: Option<String>,
178 system_id: Option<String>,
179 _executor: Executor,
180}
181
182impl Session {
183 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
184 #[cfg(test)]
185 tokio::task::yield_now().await;
186 let guard = self.db.lock().await;
187 #[cfg(test)]
188 tokio::task::yield_now().await;
189 guard
190 }
191
192 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
193 #[cfg(test)]
194 tokio::task::yield_now().await;
195 let guard = self.connection_pool.lock();
196 ConnectionPoolGuard {
197 guard,
198 _not_send: PhantomData,
199 }
200 }
201
202 fn is_staff(&self) -> bool {
203 match &self.principal {
204 Principal::User(user) => user.admin,
205 Principal::Impersonated { .. } => true,
206 }
207 }
208
209 fn user_id(&self) -> UserId {
210 match &self.principal {
211 Principal::User(user) => user.id,
212 Principal::Impersonated { user, .. } => user.id,
213 }
214 }
215
216 pub fn email(&self) -> Option<String> {
217 match &self.principal {
218 Principal::User(user) => user.email_address.clone(),
219 Principal::Impersonated { user, .. } => user.email_address.clone(),
220 }
221 }
222}
223
224impl Debug for Session {
225 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
226 let mut result = f.debug_struct("Session");
227 match &self.principal {
228 Principal::User(user) => {
229 result.field("user", &user.github_login);
230 }
231 Principal::Impersonated { user, admin } => {
232 result.field("user", &user.github_login);
233 result.field("impersonator", &admin.github_login);
234 }
235 }
236 result.field("connection_id", &self.connection_id).finish()
237 }
238}
239
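/// Wrapper around the shared `Database`. Sessions hold it behind a Tokio mutex
/// and deref through to the underlying database methods.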
240struct DbHandle(Arc<Database>);
241
242impl Deref for DbHandle {
243 type Target = Database;
244
245 fn deref(&self) -> &Self::Target {
246 self.0.as_ref()
247 }
248}
249
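/// The collab RPC server: owns the `Peer`, the connection pool, the registered
/// message handlers, and the teardown signal watched by in-flight connections.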
250pub struct Server {
251 id: parking_lot::Mutex<ServerId>,
252 peer: Arc<Peer>,
253 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
254 app_state: Arc<AppState>,
255 handlers: HashMap<TypeId, MessageHandler>,
256 teardown: watch::Sender<bool>,
257}
258
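/// Lock guard over the `ConnectionPool`. The `PhantomData<Rc<()>>` makes the guard
/// `!Send`, so holding it across an `.await` makes the surrounding future non-`Send`;
/// in tests, dropping the guard also checks the pool's invariants.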
259pub(crate) struct ConnectionPoolGuard<'a> {
260 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
261 _not_send: PhantomData<Rc<()>>,
262}
263
264#[derive(Serialize)]
265pub struct ServerSnapshot<'a> {
266 peer: &'a Peer,
267 #[serde(serialize_with = "serialize_deref")]
268 connection_pool: ConnectionPoolGuard<'a>,
269}
270
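/// Serde helper that serializes a value through its `Deref` target; used via
/// `#[serde(serialize_with = "serialize_deref")]` to snapshot the `ConnectionPool`
/// behind its guard.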
271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
272where
273 S: Serializer,
274 T: Deref<Target = U>,
275 U: Serialize,
276{
277 Serialize::serialize(value.deref(), serializer)
278}
279
280impl Server {
281 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
282 let mut server = Self {
283 id: parking_lot::Mutex::new(id),
284 peer: Peer::new(id.0 as u32),
285 app_state: app_state.clone(),
286 connection_pool: Default::default(),
287 handlers: Default::default(),
288 teardown: watch::channel(false).0,
289 };
290
291 server
292 .add_request_handler(ping)
293 .add_request_handler(create_room)
294 .add_request_handler(join_room)
295 .add_request_handler(rejoin_room)
296 .add_request_handler(leave_room)
297 .add_request_handler(set_room_participant_role)
298 .add_request_handler(call)
299 .add_request_handler(cancel_call)
300 .add_message_handler(decline_call)
301 .add_request_handler(update_participant_location)
302 .add_request_handler(share_project)
303 .add_message_handler(unshare_project)
304 .add_request_handler(join_project)
305 .add_message_handler(leave_project)
306 .add_request_handler(update_project)
307 .add_request_handler(update_worktree)
308 .add_request_handler(update_repository)
309 .add_request_handler(remove_repository)
310 .add_message_handler(start_language_server)
311 .add_message_handler(update_language_server)
312 .add_message_handler(update_diagnostic_summary)
313 .add_message_handler(update_worktree_settings)
314 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
317 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
318 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
324 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
325 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
326 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
327 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
328 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
329 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
330 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
334 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
335 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
336 .add_request_handler(
337 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
338 )
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
342 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
344 .add_request_handler(
345 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
346 )
347 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
348 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
349 .add_request_handler(
350 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
351 )
352 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
353 .add_request_handler(
354 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
355 )
356 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
357 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
358 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
359 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
360 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
361 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
362 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
365 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
368 .add_request_handler(
369 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
370 )
371 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
372 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
373 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
374 .add_request_handler(multi_lsp_query)
375 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
376 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
377 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
378 .add_message_handler(create_buffer_for_peer)
379 .add_request_handler(update_buffer)
380 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
381 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
382 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
383 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
386 .add_message_handler(
387 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
388 )
389 .add_request_handler(get_users)
390 .add_request_handler(fuzzy_search_users)
391 .add_request_handler(request_contact)
392 .add_request_handler(remove_contact)
393 .add_request_handler(respond_to_contact_request)
394 .add_message_handler(subscribe_to_channels)
395 .add_request_handler(create_channel)
396 .add_request_handler(delete_channel)
397 .add_request_handler(invite_channel_member)
398 .add_request_handler(remove_channel_member)
399 .add_request_handler(set_channel_member_role)
400 .add_request_handler(set_channel_visibility)
401 .add_request_handler(rename_channel)
402 .add_request_handler(join_channel_buffer)
403 .add_request_handler(leave_channel_buffer)
404 .add_message_handler(update_channel_buffer)
405 .add_request_handler(rejoin_channel_buffers)
406 .add_request_handler(get_channel_members)
407 .add_request_handler(respond_to_channel_invite)
408 .add_request_handler(join_channel)
409 .add_request_handler(join_channel_chat)
410 .add_message_handler(leave_channel_chat)
411 .add_request_handler(send_channel_message)
412 .add_request_handler(remove_channel_message)
413 .add_request_handler(update_channel_message)
414 .add_request_handler(get_channel_messages)
415 .add_request_handler(get_channel_messages_by_id)
416 .add_request_handler(get_notifications)
417 .add_request_handler(mark_notification_as_read)
418 .add_request_handler(move_channel)
419 .add_request_handler(reorder_channel)
420 .add_request_handler(follow)
421 .add_message_handler(unfollow)
422 .add_message_handler(update_followers)
423 .add_request_handler(get_private_user_info)
424 .add_request_handler(get_llm_api_token)
425 .add_request_handler(accept_terms_of_service)
426 .add_message_handler(acknowledge_channel_message)
427 .add_message_handler(acknowledge_buffer_version)
428 .add_request_handler(get_supermaven_api_key)
429 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
430 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
431 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
432 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
433 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
434 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
435 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
436 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
437 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
438 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
439 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
440 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
441 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
443 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
444 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
445 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
446 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
447 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
448 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
449 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
450 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
451 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
452 .add_message_handler(update_context);
453
454 Arc::new(server)
455 }
456
457 pub async fn start(&self) -> Result<()> {
458 let server_id = *self.id.lock();
459 let app_state = self.app_state.clone();
460 let peer = self.peer.clone();
461 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
462 let pool = self.connection_pool.clone();
463 let livekit_client = self.app_state.livekit_client.clone();
464
465 let span = info_span!("start server");
466 self.app_state.executor.spawn_detached(
467 async move {
468 tracing::info!("waiting for cleanup timeout");
469 timeout.await;
470 tracing::info!("cleanup timeout expired, retrieving stale rooms");
471
472 app_state
473 .db
474 .delete_stale_channel_chat_participants(
475 &app_state.config.zed_environment,
476 server_id,
477 )
478 .await
479 .trace_err();
480
481 if let Some((room_ids, channel_ids)) = app_state
482 .db
483 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
484 .await
485 .trace_err()
486 {
487 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
488 tracing::info!(
489 stale_channel_buffer_count = channel_ids.len(),
490 "retrieved stale channel buffers"
491 );
492
493 for channel_id in channel_ids {
494 if let Some(refreshed_channel_buffer) = app_state
495 .db
496 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
497 .await
498 .trace_err()
499 {
500 for connection_id in refreshed_channel_buffer.connection_ids {
501 peer.send(
502 connection_id,
503 proto::UpdateChannelBufferCollaborators {
504 channel_id: channel_id.to_proto(),
505 collaborators: refreshed_channel_buffer
506 .collaborators
507 .clone(),
508 },
509 )
510 .trace_err();
511 }
512 }
513 }
514
515 for room_id in room_ids {
516 let mut contacts_to_update = HashSet::default();
517 let mut canceled_calls_to_user_ids = Vec::new();
518 let mut livekit_room = String::new();
519 let mut delete_livekit_room = false;
520
521 if let Some(mut refreshed_room) = app_state
522 .db
523 .clear_stale_room_participants(room_id, server_id)
524 .await
525 .trace_err()
526 {
527 tracing::info!(
528 room_id = room_id.0,
529 new_participant_count = refreshed_room.room.participants.len(),
530 "refreshed room"
531 );
532 room_updated(&refreshed_room.room, &peer);
533 if let Some(channel) = refreshed_room.channel.as_ref() {
534 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
535 }
536 contacts_to_update
537 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
538 contacts_to_update
539 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
540 canceled_calls_to_user_ids =
541 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
542 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
543 delete_livekit_room = refreshed_room.room.participants.is_empty();
544 }
545
546 {
547 let pool = pool.lock();
548 for canceled_user_id in canceled_calls_to_user_ids {
549 for connection_id in pool.user_connection_ids(canceled_user_id) {
550 peer.send(
551 connection_id,
552 proto::CallCanceled {
553 room_id: room_id.to_proto(),
554 },
555 )
556 .trace_err();
557 }
558 }
559 }
560
561 for user_id in contacts_to_update {
562 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
563 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
564 if let Some((busy, contacts)) = busy.zip(contacts) {
565 let pool = pool.lock();
566 let updated_contact = contact_for_user(user_id, busy, &pool);
567 for contact in contacts {
568 if let db::Contact::Accepted {
569 user_id: contact_user_id,
570 ..
571 } = contact
572 {
573 for contact_conn_id in
574 pool.user_connection_ids(contact_user_id)
575 {
576 peer.send(
577 contact_conn_id,
578 proto::UpdateContacts {
579 contacts: vec![updated_contact.clone()],
580 remove_contacts: Default::default(),
581 incoming_requests: Default::default(),
582 remove_incoming_requests: Default::default(),
583 outgoing_requests: Default::default(),
584 remove_outgoing_requests: Default::default(),
585 },
586 )
587 .trace_err();
588 }
589 }
590 }
591 }
592 }
593
594 if let Some(live_kit) = livekit_client.as_ref() {
595 if delete_livekit_room {
596 live_kit.delete_room(livekit_room).await.trace_err();
597 }
598 }
599 }
600 }
601
602 app_state
603 .db
604 .delete_stale_channel_chat_participants(
605 &app_state.config.zed_environment,
606 server_id,
607 )
608 .await
609 .trace_err();
610
611 app_state
612 .db
613 .clear_old_worktree_entries(server_id)
614 .await
615 .trace_err();
616
617 app_state
618 .db
619 .delete_stale_servers(&app_state.config.zed_environment, server_id)
620 .await
621 .trace_err();
622 }
623 .instrument(span),
624 );
625 Ok(())
626 }
627
628 pub fn teardown(&self) {
629 self.peer.teardown();
630 self.connection_pool.lock().reset();
631 let _ = self.teardown.send(true);
632 }
633
634 #[cfg(test)]
635 pub fn reset(&self, id: ServerId) {
636 self.teardown();
637 *self.id.lock() = id;
638 self.peer.reset(id.0 as u32);
639 let _ = self.teardown.send(false);
640 }
641
642 #[cfg(test)]
643 pub fn id(&self) -> ServerId {
644 *self.id.lock()
645 }
646
647 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
648 where
649 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
650 Fut: 'static + Send + Future<Output = Result<()>>,
651 M: EnvelopedMessage,
652 {
653 let prev_handler = self.handlers.insert(
654 TypeId::of::<M>(),
655 Box::new(move |envelope, session| {
656 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
657 let received_at = envelope.received_at;
658 tracing::info!("message received");
659 let start_time = Instant::now();
660 let future = (handler)(*envelope, session);
661 async move {
662 let result = future.await;
663 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
664 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
665 let queue_duration_ms = total_duration_ms - processing_duration_ms;
666
667 match result {
668 Err(error) => {
669 tracing::error!(
670 ?error,
671 total_duration_ms,
672 processing_duration_ms,
673 queue_duration_ms,
674 "error handling message"
675 )
676 }
677 Ok(()) => tracing::info!(
678 total_duration_ms,
679 processing_duration_ms,
680 queue_duration_ms,
681 "finished handling message"
682 ),
683 }
684 }
685 .boxed()
686 }),
687 );
688 if prev_handler.is_some() {
689 panic!("registered a handler for the same message twice");
690 }
691 self
692 }
693
694 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
695 where
696 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
697 Fut: 'static + Send + Future<Output = Result<()>>,
698 M: EnvelopedMessage,
699 {
700 self.add_handler(move |envelope, session| handler(envelope.payload, session));
701 self
702 }
703
704 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
705 where
706 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
707 Fut: Send + Future<Output = Result<()>>,
708 M: RequestMessage,
709 {
710 let handler = Arc::new(handler);
711 self.add_handler(move |envelope, session| {
712 let receipt = envelope.receipt();
713 let handler = handler.clone();
714 async move {
715 let peer = session.peer.clone();
716 let responded = Arc::new(AtomicBool::default());
717 let response = Response {
718 peer: peer.clone(),
719 responded: responded.clone(),
720 receipt,
721 };
722 match (handler)(envelope.payload, response, session).await {
723 Ok(()) => {
724 if responded.load(std::sync::atomic::Ordering::SeqCst) {
725 Ok(())
726 } else {
727 Err(anyhow!("handler did not send a response"))?
728 }
729 }
730 Err(error) => {
731 let proto_err = match &error {
732 Error::Internal(err) => err.to_proto(),
733 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
734 };
735 peer.respond_with_error(receipt, proto_err)?;
736 Err(error)
737 }
738 }
739 }
740 })
741 }
742
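    /// Returns a future that drives a single client connection: registers it with the
    /// `Peer`, sends the initial client update, then dispatches incoming messages to the
    /// registered handlers (background messages on the executor, foreground messages
    /// through a bounded set of concurrent futures) until the connection closes or the
    /// server tears down, after which `connection_lost` performs cleanup.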
743 pub fn handle_connection(
744 self: &Arc<Self>,
745 connection: Connection,
746 address: String,
747 principal: Principal,
748 zed_version: ZedVersion,
749 release_channel: Option<String>,
750 user_agent: Option<String>,
751 geoip_country_code: Option<String>,
752 system_id: Option<String>,
753 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
754 executor: Executor,
755 connection_guard: Option<ConnectionGuard>,
756 ) -> impl Future<Output = ()> + use<> {
757 let this = self.clone();
758 let span = info_span!("handle connection", %address,
759 connection_id=field::Empty,
760 user_id=field::Empty,
761 login=field::Empty,
762 impersonator=field::Empty,
763 user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty
765 );
766 principal.update_span(&span);
767 if let Some(user_agent) = user_agent {
768 span.record("user_agent", user_agent);
769 }
770 if let Some(release_channel) = release_channel {
771 span.record("release_channel", release_channel);
772 }
773
774 if let Some(country_code) = geoip_country_code.as_ref() {
775 span.record("geoip_country_code", country_code);
776 }
777
778 let mut teardown = self.teardown.subscribe();
779 async move {
780 if *teardown.borrow() {
781 tracing::error!("server is tearing down");
782 return;
783 }
784
785 let (connection_id, handle_io, mut incoming_rx) =
786 this.peer.add_connection(connection, {
787 let executor = executor.clone();
788 move |duration| executor.sleep(duration)
789 });
790 tracing::Span::current().record("connection_id", format!("{}", connection_id));
791
792 tracing::info!("connection opened");
793
794 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
795 let http_client = match ReqwestClient::user_agent(&user_agent) {
796 Ok(http_client) => Arc::new(http_client),
797 Err(error) => {
798 tracing::error!(?error, "failed to create HTTP client");
799 return;
800 }
801 };
802
803 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
804 |supermaven_admin_api_key| {
805 Arc::new(SupermavenAdminApi::new(
806 supermaven_admin_api_key.to_string(),
807 http_client.clone(),
808 ))
809 },
810 );
811
812 let session = Session {
813 principal: principal.clone(),
814 connection_id,
815 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
816 peer: this.peer.clone(),
817 connection_pool: this.connection_pool.clone(),
818 app_state: this.app_state.clone(),
819 geoip_country_code,
820 system_id,
821 _executor: executor.clone(),
822 supermaven_client,
823 };
824
825 if let Err(error) = this
826 .send_initial_client_update(
827 connection_id,
828 zed_version,
829 send_connection_id,
830 &session,
831 )
832 .await
833 {
834 tracing::error!(?error, "failed to send initial client update");
835 return;
836 }
837 drop(connection_guard);
838
839 let handle_io = handle_io.fuse();
840 futures::pin_mut!(handle_io);
841
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B while
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other's request, resulting in a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
850 const MAX_CONCURRENT_HANDLERS: usize = 256;
851 let mut foreground_message_handlers = FuturesUnordered::new();
852 let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
853 let get_concurrent_handlers = {
854 let concurrent_handlers = concurrent_handlers.clone();
855 move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
856 };
857 loop {
858 let next_message = async {
859 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
860 let message = incoming_rx.next().await;
861 // Cache the concurrent_handlers here, so that we know what the
862 // queue looks like as each handler starts
863 (permit, message, get_concurrent_handlers())
864 }
865 .fuse();
866 futures::pin_mut!(next_message);
867 futures::select_biased! {
868 _ = teardown.changed().fuse() => return,
869 result = handle_io => {
870 if let Err(error) = result {
871 tracing::error!(?error, "error handling I/O");
872 }
873 break;
874 }
875 _ = foreground_message_handlers.next() => {}
876 next_message = next_message => {
877 let (permit, message, concurrent_handlers) = next_message;
878 if let Some(message) = message {
879 let type_name = message.payload_type_name();
880 // note: we copy all the fields from the parent span so we can query them in the logs.
881 // (https://github.com/tokio-rs/tracing/issues/2670).
882 let span = tracing::info_span!("receive message",
883 %connection_id,
884 %address,
885 type_name,
886 concurrent_handlers,
887 user_id=field::Empty,
888 login=field::Empty,
889 impersonator=field::Empty,
890 multi_lsp_query_request=field::Empty,
891 );
892 principal.update_span(&span);
893 let span_enter = span.enter();
894 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
895 let is_background = message.is_background();
896 let handle_message = (handler)(message, session.clone());
897 drop(span_enter);
898
899 let handle_message = async move {
900 handle_message.await;
901 drop(permit);
902 }.instrument(span);
903 if is_background {
904 executor.spawn_detached(handle_message);
905 } else {
906 foreground_message_handlers.push(handle_message);
907 }
908 } else {
909 tracing::error!("no message handler");
910 }
911 } else {
912 tracing::info!("connection closed");
913 break;
914 }
915 }
916 }
917 }
918
919 drop(foreground_message_handlers);
920 let concurrent_handlers = get_concurrent_handlers();
921 tracing::info!(concurrent_handlers, "signing out");
922 if let Err(error) = connection_lost(session, teardown, executor).await {
923 tracing::error!(?error, "error signing out");
924 }
925 }
926 .instrument(span)
927 }
928
929 async fn send_initial_client_update(
930 &self,
931 connection_id: ConnectionId,
932 zed_version: ZedVersion,
933 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
934 session: &Session,
935 ) -> Result<()> {
936 self.peer.send(
937 connection_id,
938 proto::Hello {
939 peer_id: Some(connection_id.into()),
940 },
941 )?;
942 tracing::info!("sent hello message");
943 if let Some(send_connection_id) = send_connection_id.take() {
944 let _ = send_connection_id.send(connection_id);
945 }
946
947 match &session.principal {
948 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
949 if !user.connected_once {
950 self.peer.send(connection_id, proto::ShowContacts {})?;
951 self.app_state
952 .db
953 .set_user_connected_once(user.id, true)
954 .await?;
955 }
956
957 update_user_plan(session).await?;
958
959 let contacts = self.app_state.db.get_contacts(user.id).await?;
960
961 {
962 let mut pool = self.connection_pool.lock();
963 pool.add_connection(connection_id, user.id, user.admin, zed_version);
964 self.peer.send(
965 connection_id,
966 build_initial_contacts_update(contacts, &pool),
967 )?;
968 }
969
970 if should_auto_subscribe_to_channels(zed_version) {
971 subscribe_user_to_channels(user.id, session).await?;
972 }
973
974 if let Some(incoming_call) =
975 self.app_state.db.incoming_call_for_user(user.id).await?
976 {
977 self.peer.send(connection_id, incoming_call)?;
978 }
979
980 update_user_contacts(user.id, session).await?;
981 }
982 }
983
984 Ok(())
985 }
986
987 pub async fn invite_code_redeemed(
988 self: &Arc<Self>,
989 inviter_id: UserId,
990 invitee_id: UserId,
991 ) -> Result<()> {
992 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
993 if let Some(code) = &user.invite_code {
994 let pool = self.connection_pool.lock();
995 let invitee_contact = contact_for_user(invitee_id, false, &pool);
996 for connection_id in pool.user_connection_ids(inviter_id) {
997 self.peer.send(
998 connection_id,
999 proto::UpdateContacts {
1000 contacts: vec![invitee_contact.clone()],
1001 ..Default::default()
1002 },
1003 )?;
1004 self.peer.send(
1005 connection_id,
1006 proto::UpdateInviteInfo {
1007 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1008 count: user.invite_count as u32,
1009 },
1010 )?;
1011 }
1012 }
1013 }
1014 Ok(())
1015 }
1016
1017 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1018 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1019 if let Some(invite_code) = &user.invite_code {
1020 let pool = self.connection_pool.lock();
1021 for connection_id in pool.user_connection_ids(user_id) {
1022 self.peer.send(
1023 connection_id,
1024 proto::UpdateInviteInfo {
1025 url: format!(
1026 "{}{}",
1027 self.app_state.config.invite_link_prefix, invite_code
1028 ),
1029 count: user.invite_count as u32,
1030 },
1031 )?;
1032 }
1033 }
1034 }
1035 Ok(())
1036 }
1037
1038 pub async fn update_plan_for_user(
1039 self: &Arc<Self>,
1040 user_id: UserId,
1041 update_user_plan: proto::UpdateUserPlan,
1042 ) -> Result<()> {
1043 let pool = self.connection_pool.lock();
1044 for connection_id in pool.user_connection_ids(user_id) {
1045 self.peer
1046 .send(connection_id, update_user_plan.clone())
1047 .trace_err();
1048 }
1049
1050 Ok(())
1051 }
1052
1053 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1054 /// message on the Collab server.
1055 ///
1056 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1057 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1058 let user = self
1059 .app_state
1060 .db
1061 .get_user_by_id(user_id)
1062 .await?
1063 .context("user not found")?;
1064
1065 let update_user_plan = make_update_user_plan_message(
1066 &user,
1067 user.admin,
1068 &self.app_state.db,
1069 self.app_state.llm_db.clone(),
1070 )
1071 .await?;
1072
1073 self.update_plan_for_user(user_id, update_user_plan).await
1074 }
1075
1076 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1077 let pool = self.connection_pool.lock();
1078 for connection_id in pool.user_connection_ids(user_id) {
1079 self.peer
1080 .send(connection_id, proto::RefreshLlmToken {})
1081 .trace_err();
1082 }
1083 }
1084
1085 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1086 ServerSnapshot {
1087 connection_pool: ConnectionPoolGuard {
1088 guard: self.connection_pool.lock(),
1089 _not_send: PhantomData,
1090 },
1091 peer: &self.peer,
1092 }
1093 }
1094}
1095
1096impl Deref for ConnectionPoolGuard<'_> {
1097 type Target = ConnectionPool;
1098
1099 fn deref(&self) -> &Self::Target {
1100 &self.guard
1101 }
1102}
1103
1104impl DerefMut for ConnectionPoolGuard<'_> {
1105 fn deref_mut(&mut self) -> &mut Self::Target {
1106 &mut self.guard
1107 }
1108}
1109
1110impl Drop for ConnectionPoolGuard<'_> {
1111 fn drop(&mut self) {
1112 #[cfg(test)]
1113 self.check_invariants();
1114 }
1115}
1116
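/// Calls `f` for every receiver except the optional sender, logging (rather than
/// propagating) per-connection send failures.
///
/// A sketch of a typical call (identifiers are illustrative, mirroring the
/// project-message forwarding elsewhere in this module):
///
/// ```ignore
/// broadcast(Some(session.connection_id), guest_connection_ids, |connection_id| {
///     session.peer.forward_send(session.connection_id, connection_id, message.clone())
/// });
/// ```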
1117fn broadcast<F>(
1118 sender_id: Option<ConnectionId>,
1119 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1120 mut f: F,
1121) where
1122 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1123{
1124 for receiver_id in receiver_ids {
1125 if Some(receiver_id) != sender_id {
1126 if let Err(error) = f(receiver_id) {
                tracing::error!("failed to send to {:?}: {}", receiver_id, error);
1128 }
1129 }
1130 }
1131}
1132
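/// Typed `x-zed-protocol-version` header carrying the client's RPC protocol version.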
1133pub struct ProtocolVersion(u32);
1134
1135impl Header for ProtocolVersion {
1136 fn name() -> &'static HeaderName {
1137 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1138 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1139 }
1140
1141 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1142 where
1143 Self: Sized,
1144 I: Iterator<Item = &'i axum::http::HeaderValue>,
1145 {
1146 let version = values
1147 .next()
1148 .ok_or_else(axum::headers::Error::invalid)?
1149 .to_str()
1150 .map_err(|_| axum::headers::Error::invalid())?
1151 .parse()
1152 .map_err(|_| axum::headers::Error::invalid())?;
1153 Ok(Self(version))
1154 }
1155
1156 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1157 values.extend([self.0.to_string().parse().unwrap()]);
1158 }
1159}
1160
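/// Typed `x-zed-app-version` header carrying the client's semantic app version.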
1161pub struct AppVersionHeader(SemanticVersion);
1162impl Header for AppVersionHeader {
1163 fn name() -> &'static HeaderName {
1164 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1165 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1166 }
1167
1168 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1169 where
1170 Self: Sized,
1171 I: Iterator<Item = &'i axum::http::HeaderValue>,
1172 {
1173 let version = values
1174 .next()
1175 .ok_or_else(axum::headers::Error::invalid)?
1176 .to_str()
1177 .map_err(|_| axum::headers::Error::invalid())?
1178 .parse()
1179 .map_err(|_| axum::headers::Error::invalid())?;
1180 Ok(Self(version))
1181 }
1182
1183 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1184 values.extend([self.0.to_string().parse().unwrap()]);
1185 }
1186}
1187
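/// Typed `x-zed-release-channel` header identifying the client's release channel.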
1188#[derive(Debug)]
1189pub struct ReleaseChannelHeader(String);
1190
1191impl Header for ReleaseChannelHeader {
1192 fn name() -> &'static HeaderName {
1193 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1194 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1195 }
1196
1197 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1198 where
1199 Self: Sized,
1200 I: Iterator<Item = &'i axum::http::HeaderValue>,
1201 {
1202 Ok(Self(
1203 values
1204 .next()
1205 .ok_or_else(axum::headers::Error::invalid)?
1206 .to_str()
1207 .map_err(|_| axum::headers::Error::invalid())?
1208 .to_owned(),
1209 ))
1210 }
1211
1212 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1213 values.extend([self.0.parse().unwrap()]);
1214 }
1215}
1216
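/// Builds the collab router: `/rpc` upgrades authenticated WebSocket connections
/// (guarded by `auth::validate_header`), and `/metrics` serves Prometheus metrics.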
1217pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1218 Router::new()
1219 .route("/rpc", get(handle_websocket_request))
1220 .layer(
1221 ServiceBuilder::new()
1222 .layer(Extension(server.app_state.clone()))
1223 .layer(middleware::from_fn(auth::validate_header)),
1224 )
1225 .route("/metrics", get(handle_metrics))
1226 .layer(Extension(server))
1227}
1228
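/// Validates the protocol version, app version, and connection capacity, then upgrades
/// the request to a WebSocket and hands the connection to `Server::handle_connection`.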
1229pub async fn handle_websocket_request(
1230 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1231 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1232 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1233 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1234 Extension(server): Extension<Arc<Server>>,
1235 Extension(principal): Extension<Principal>,
1236 user_agent: Option<TypedHeader<UserAgent>>,
1237 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1238 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1239 ws: WebSocketUpgrade,
1240) -> axum::response::Response {
1241 if protocol_version != rpc::PROTOCOL_VERSION {
1242 return (
1243 StatusCode::UPGRADE_REQUIRED,
1244 "client must be upgraded".to_string(),
1245 )
1246 .into_response();
1247 }
1248
1249 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1250 return (
1251 StatusCode::UPGRADE_REQUIRED,
1252 "no version header found".to_string(),
1253 )
1254 .into_response();
1255 };
1256
1257 let release_channel = release_channel_header.map(|header| header.0.0);
1258
1259 if !version.can_collaborate() {
1260 return (
1261 StatusCode::UPGRADE_REQUIRED,
1262 "client must be upgraded".to_string(),
1263 )
1264 .into_response();
1265 }
1266
1267 let socket_address = socket_address.to_string();
1268
1269 // Acquire connection guard before WebSocket upgrade
1270 let connection_guard = match ConnectionGuard::try_acquire() {
1271 Ok(guard) => guard,
1272 Err(()) => {
1273 return (
1274 StatusCode::SERVICE_UNAVAILABLE,
1275 "Too many concurrent connections",
1276 )
1277 .into_response();
1278 }
1279 };
1280
1281 ws.on_upgrade(move |socket| {
1282 let socket = socket
1283 .map_ok(to_tungstenite_message)
1284 .err_into()
1285 .with(|message| async move { to_axum_message(message) });
1286 let connection = Connection::new(Box::pin(socket));
1287 async move {
1288 server
1289 .handle_connection(
1290 connection,
1291 socket_address,
1292 principal,
1293 version,
1294 release_channel,
1295 user_agent.map(|header| header.to_string()),
1296 country_code_header.map(|header| header.to_string()),
1297 system_id_header.map(|header| header.to_string()),
1298 None,
1299 Executor::Production,
1300 Some(connection_guard),
1301 )
1302 .await;
1303 }
1304 })
1305}
1306
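/// Serves the Prometheus metrics endpoint, refreshing the connection and
/// shared-project gauges before encoding all registered metrics.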
1307pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1308 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1309 let connections_metric = CONNECTIONS_METRIC
1310 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1311
1312 let connections = server
1313 .connection_pool
1314 .lock()
1315 .connections()
1316 .filter(|connection| !connection.admin)
1317 .count();
1318 connections_metric.set(connections as _);
1319
1320 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1321 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1322 register_int_gauge!(
1323 "shared_projects",
1324 "number of open projects with one or more guests"
1325 )
1326 .unwrap()
1327 });
1328
1329 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1330 shared_projects_metric.set(shared_projects as _);
1331
1332 let encoder = prometheus::TextEncoder::new();
1333 let metric_families = prometheus::gather();
1334 let encoded_metrics = encoder
1335 .encode_to_string(&metric_families)
1336 .map_err(|err| anyhow!("{err}"))?;
1337 Ok(encoded_metrics)
1338}
1339
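/// Cleans up after a dropped connection: removes it from the pool and marks it lost in
/// the database immediately, then, unless the server is tearing down, waits
/// `RECONNECT_TIMEOUT` before releasing the session's room, channel buffers, and pending
/// calls and updating the user's contacts.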
1340#[instrument(err, skip(executor))]
1341async fn connection_lost(
1342 session: Session,
1343 mut teardown: watch::Receiver<bool>,
1344 executor: Executor,
1345) -> Result<()> {
1346 session.peer.disconnect(session.connection_id);
1347 session
1348 .connection_pool()
1349 .await
1350 .remove_connection(session.connection_id)?;
1351
1352 session
1353 .db()
1354 .await
1355 .connection_lost(session.connection_id)
1356 .await
1357 .trace_err();
1358
1359 futures::select_biased! {
1360 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1361
1362 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1363 leave_room_for_session(&session, session.connection_id).await.trace_err();
1364 leave_channel_buffers_for_session(&session)
1365 .await
1366 .trace_err();
1367
1368 if !session
1369 .connection_pool()
1370 .await
1371 .is_user_online(session.user_id())
1372 {
1373 let db = session.db().await;
1374 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1375 room_updated(&room, &session.peer);
1376 }
1377 }
1378
1379 update_user_contacts(session.user_id(), &session).await?;
1380 },
1381 _ = teardown.changed().fuse() => {}
1382 }
1383
1384 Ok(())
1385}
1386
1387/// Acknowledges a ping from a client, used to keep the connection alive.
1388async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1389 response.send(proto::Ack {})?;
1390 Ok(())
1391}
1392
/// Creates a new room for calling (outside of channels).
1394async fn create_room(
1395 _request: proto::CreateRoom,
1396 response: Response<proto::CreateRoom>,
1397 session: Session,
1398) -> Result<()> {
1399 let livekit_room = nanoid::nanoid!(30);
1400
1401 let live_kit_connection_info = util::maybe!(async {
1402 let live_kit = session.app_state.livekit_client.as_ref();
1403 let live_kit = live_kit?;
1404 let user_id = session.user_id().to_string();
1405
1406 let token = live_kit
            .room_token(&livekit_room, &user_id)
1408 .trace_err()?;
1409
1410 Some(proto::LiveKitConnectionInfo {
1411 server_url: live_kit.url().into(),
1412 token,
1413 can_publish: true,
1414 })
1415 })
1416 .await;
1417
1418 let room = session
1419 .db()
1420 .await
1421 .create_room(session.user_id(), session.connection_id, &livekit_room)
1422 .await?;
1423
1424 response.send(proto::CreateRoomResponse {
1425 room: Some(room.clone()),
1426 live_kit_connection_info,
1427 })?;
1428
1429 update_user_contacts(session.user_id(), &session).await?;
1430 Ok(())
1431}
1432
1433/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1434async fn join_room(
1435 request: proto::JoinRoom,
1436 response: Response<proto::JoinRoom>,
1437 session: Session,
1438) -> Result<()> {
1439 let room_id = RoomId::from_proto(request.id);
1440
1441 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1442
1443 if let Some(channel_id) = channel_id {
1444 return join_channel_internal(channel_id, Box::new(response), session).await;
1445 }
1446
1447 let joined_room = {
1448 let room = session
1449 .db()
1450 .await
1451 .join_room(room_id, session.user_id(), session.connection_id)
1452 .await?;
1453 room_updated(&room.room, &session.peer);
1454 room.into_inner()
1455 };
1456
1457 for connection_id in session
1458 .connection_pool()
1459 .await
1460 .user_connection_ids(session.user_id())
1461 {
1462 session
1463 .peer
1464 .send(
1465 connection_id,
1466 proto::CallCanceled {
1467 room_id: room_id.to_proto(),
1468 },
1469 )
1470 .trace_err();
1471 }
1472
1473 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1474 {
1475 live_kit
1476 .room_token(
1477 &joined_room.room.livekit_room,
1478 &session.user_id().to_string(),
1479 )
1480 .trace_err()
1481 .map(|token| proto::LiveKitConnectionInfo {
1482 server_url: live_kit.url().into(),
1483 token,
1484 can_publish: true,
1485 })
1486 } else {
1487 None
1488 };
1489
1490 response.send(proto::JoinRoomResponse {
1491 room: Some(joined_room.room),
1492 channel_id: None,
1493 live_kit_connection_info,
1494 })?;
1495
1496 update_user_contacts(session.user_id(), &session).await?;
1497 Ok(())
1498}
1499
/// Rejoins a room; used to reconnect to a room after connection errors.
1501async fn rejoin_room(
1502 request: proto::RejoinRoom,
1503 response: Response<proto::RejoinRoom>,
1504 session: Session,
1505) -> Result<()> {
1506 let room;
1507 let channel;
1508 {
1509 let mut rejoined_room = session
1510 .db()
1511 .await
1512 .rejoin_room(request, session.user_id(), session.connection_id)
1513 .await?;
1514
1515 response.send(proto::RejoinRoomResponse {
1516 room: Some(rejoined_room.room.clone()),
1517 reshared_projects: rejoined_room
1518 .reshared_projects
1519 .iter()
1520 .map(|project| proto::ResharedProject {
1521 id: project.id.to_proto(),
1522 collaborators: project
1523 .collaborators
1524 .iter()
1525 .map(|collaborator| collaborator.to_proto())
1526 .collect(),
1527 })
1528 .collect(),
1529 rejoined_projects: rejoined_room
1530 .rejoined_projects
1531 .iter()
1532 .map(|rejoined_project| rejoined_project.to_proto())
1533 .collect(),
1534 })?;
1535 room_updated(&rejoined_room.room, &session.peer);
1536
1537 for project in &rejoined_room.reshared_projects {
1538 for collaborator in &project.collaborators {
1539 session
1540 .peer
1541 .send(
1542 collaborator.connection_id,
1543 proto::UpdateProjectCollaborator {
1544 project_id: project.id.to_proto(),
1545 old_peer_id: Some(project.old_connection_id.into()),
1546 new_peer_id: Some(session.connection_id.into()),
1547 },
1548 )
1549 .trace_err();
1550 }
1551
1552 broadcast(
1553 Some(session.connection_id),
1554 project
1555 .collaborators
1556 .iter()
1557 .map(|collaborator| collaborator.connection_id),
1558 |connection_id| {
1559 session.peer.forward_send(
1560 session.connection_id,
1561 connection_id,
1562 proto::UpdateProject {
1563 project_id: project.id.to_proto(),
1564 worktrees: project.worktrees.clone(),
1565 },
1566 )
1567 },
1568 );
1569 }
1570
1571 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1572
1573 let rejoined_room = rejoined_room.into_inner();
1574
1575 room = rejoined_room.room;
1576 channel = rejoined_room.channel;
1577 }
1578
1579 if let Some(channel) = channel {
1580 channel_updated(
1581 &channel,
1582 &room,
1583 &session.peer,
1584 &*session.connection_pool().await,
1585 );
1586 }
1587
1588 update_user_contacts(session.user_id(), &session).await?;
1589 Ok(())
1590}
1591
1592fn notify_rejoined_projects(
1593 rejoined_projects: &mut Vec<RejoinedProject>,
1594 session: &Session,
1595) -> Result<()> {
1596 for project in rejoined_projects.iter() {
1597 for collaborator in &project.collaborators {
1598 session
1599 .peer
1600 .send(
1601 collaborator.connection_id,
1602 proto::UpdateProjectCollaborator {
1603 project_id: project.id.to_proto(),
1604 old_peer_id: Some(project.old_connection_id.into()),
1605 new_peer_id: Some(session.connection_id.into()),
1606 },
1607 )
1608 .trace_err();
1609 }
1610 }
1611
1612 for project in rejoined_projects {
1613 for worktree in mem::take(&mut project.worktrees) {
1614 // Stream this worktree's entries.
1615 let message = proto::UpdateWorktree {
1616 project_id: project.id.to_proto(),
1617 worktree_id: worktree.id,
1618 abs_path: worktree.abs_path.clone(),
1619 root_name: worktree.root_name,
1620 updated_entries: worktree.updated_entries,
1621 removed_entries: worktree.removed_entries,
1622 scan_id: worktree.scan_id,
1623 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1624 updated_repositories: worktree.updated_repositories,
1625 removed_repositories: worktree.removed_repositories,
1626 };
1627 for update in proto::split_worktree_update(message) {
1628 session.peer.send(session.connection_id, update)?;
1629 }
1630
1631 // Stream this worktree's diagnostics.
1632 for summary in worktree.diagnostic_summaries {
1633 session.peer.send(
1634 session.connection_id,
1635 proto::UpdateDiagnosticSummary {
1636 project_id: project.id.to_proto(),
1637 worktree_id: worktree.id,
1638 summary: Some(summary),
1639 },
1640 )?;
1641 }
1642
1643 for settings_file in worktree.settings_files {
1644 session.peer.send(
1645 session.connection_id,
1646 proto::UpdateWorktreeSettings {
1647 project_id: project.id.to_proto(),
1648 worktree_id: worktree.id,
1649 path: settings_file.path,
1650 content: Some(settings_file.content),
1651 kind: Some(settings_file.kind.to_proto().into()),
1652 },
1653 )?;
1654 }
1655 }
1656
1657 for repository in mem::take(&mut project.updated_repositories) {
1658 for update in split_repository_update(repository) {
1659 session.peer.send(session.connection_id, update)?;
1660 }
1661 }
1662
1663 for id in mem::take(&mut project.removed_repositories) {
1664 session.peer.send(
1665 session.connection_id,
1666 proto::RemoveRepository {
1667 project_id: project.id.to_proto(),
1668 id,
1669 },
1670 )?;
1671 }
1672 }
1673
1674 Ok(())
1675}
1676
/// Leaves the room, disconnecting the session from it.
1678async fn leave_room(
1679 _: proto::LeaveRoom,
1680 response: Response<proto::LeaveRoom>,
1681 session: Session,
1682) -> Result<()> {
1683 leave_room_for_session(&session, session.connection_id).await?;
1684 response.send(proto::Ack {})?;
1685 Ok(())
1686}
1687
1688/// Updates the permissions of someone else in the room.
1689async fn set_room_participant_role(
1690 request: proto::SetRoomParticipantRole,
1691 response: Response<proto::SetRoomParticipantRole>,
1692 session: Session,
1693) -> Result<()> {
1694 let user_id = UserId::from_proto(request.user_id);
1695 let role = ChannelRole::from(request.role());
1696
1697 let (livekit_room, can_publish) = {
1698 let room = session
1699 .db()
1700 .await
1701 .set_room_participant_role(
1702 session.user_id(),
1703 RoomId::from_proto(request.room_id),
1704 user_id,
1705 role,
1706 )
1707 .await?;
1708
1709 let livekit_room = room.livekit_room.clone();
1710 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1711 room_updated(&room, &session.peer);
1712 (livekit_room, can_publish)
1713 };
1714
1715 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1716 live_kit
1717 .update_participant(
1718 livekit_room.clone(),
1719 request.user_id.to_string(),
1720 livekit_api::proto::ParticipantPermission {
1721 can_subscribe: true,
1722 can_publish,
1723 can_publish_data: can_publish,
1724 hidden: false,
1725 recorder: false,
1726 },
1727 )
1728 .await
1729 .trace_err();
1730 }
1731
1732 response.send(proto::Ack {})?;
1733 Ok(())
1734}
1735
/// Call someone else into the current room.
1737async fn call(
1738 request: proto::Call,
1739 response: Response<proto::Call>,
1740 session: Session,
1741) -> Result<()> {
1742 let room_id = RoomId::from_proto(request.room_id);
1743 let calling_user_id = session.user_id();
1744 let calling_connection_id = session.connection_id;
1745 let called_user_id = UserId::from_proto(request.called_user_id);
1746 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1747 if !session
1748 .db()
1749 .await
1750 .has_contact(calling_user_id, called_user_id)
1751 .await?
1752 {
1753 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1754 }
1755
1756 let incoming_call = {
1757 let (room, incoming_call) = &mut *session
1758 .db()
1759 .await
1760 .call(
1761 room_id,
1762 calling_user_id,
1763 calling_connection_id,
1764 called_user_id,
1765 initial_project_id,
1766 )
1767 .await?;
1768 room_updated(room, &session.peer);
1769 mem::take(incoming_call)
1770 };
1771 update_user_contacts(called_user_id, &session).await?;
1772
1773 let mut calls = session
1774 .connection_pool()
1775 .await
1776 .user_connection_ids(called_user_id)
1777 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1778 .collect::<FuturesUnordered<_>>();
1779
1780 while let Some(call_response) = calls.next().await {
1781 match call_response.as_ref() {
1782 Ok(_) => {
1783 response.send(proto::Ack {})?;
1784 return Ok(());
1785 }
1786 Err(_) => {
1787 call_response.trace_err();
1788 }
1789 }
1790 }
1791
1792 {
1793 let room = session
1794 .db()
1795 .await
1796 .call_failed(room_id, called_user_id)
1797 .await?;
1798 room_updated(&room, &session.peer);
1799 }
1800 update_user_contacts(called_user_id, &session).await?;
1801
1802 Err(anyhow!("failed to ring user"))?
1803}
1804
1805/// Cancel an outgoing call.
1806async fn cancel_call(
1807 request: proto::CancelCall,
1808 response: Response<proto::CancelCall>,
1809 session: Session,
1810) -> Result<()> {
1811 let called_user_id = UserId::from_proto(request.called_user_id);
1812 let room_id = RoomId::from_proto(request.room_id);
1813 {
1814 let room = session
1815 .db()
1816 .await
1817 .cancel_call(room_id, session.connection_id, called_user_id)
1818 .await?;
1819 room_updated(&room, &session.peer);
1820 }
1821
1822 for connection_id in session
1823 .connection_pool()
1824 .await
1825 .user_connection_ids(called_user_id)
1826 {
1827 session
1828 .peer
1829 .send(
1830 connection_id,
1831 proto::CallCanceled {
1832 room_id: room_id.to_proto(),
1833 },
1834 )
1835 .trace_err();
1836 }
1837 response.send(proto::Ack {})?;
1838
1839 update_user_contacts(called_user_id, &session).await?;
1840 Ok(())
1841}
1842
1843/// Decline an incoming call.
1844async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1845 let room_id = RoomId::from_proto(message.room_id);
1846 {
1847 let room = session
1848 .db()
1849 .await
1850 .decline_call(Some(room_id), session.user_id())
1851 .await?
1852 .context("declining call")?;
1853 room_updated(&room, &session.peer);
1854 }
1855
1856 for connection_id in session
1857 .connection_pool()
1858 .await
1859 .user_connection_ids(session.user_id())
1860 {
1861 session
1862 .peer
1863 .send(
1864 connection_id,
1865 proto::CallCanceled {
1866 room_id: room_id.to_proto(),
1867 },
1868 )
1869 .trace_err();
1870 }
1871 update_user_contacts(session.user_id(), &session).await?;
1872 Ok(())
1873}
1874
1875/// Updates other participants in the room with your current location.
1876async fn update_participant_location(
1877 request: proto::UpdateParticipantLocation,
1878 response: Response<proto::UpdateParticipantLocation>,
1879 session: Session,
1880) -> Result<()> {
1881 let room_id = RoomId::from_proto(request.room_id);
1882 let location = request.location.context("invalid location")?;
1883
1884 let db = session.db().await;
1885 let room = db
1886 .update_room_participant_location(room_id, session.connection_id, location)
1887 .await?;
1888
1889 room_updated(&room, &session.peer);
1890 response.send(proto::Ack {})?;
1891 Ok(())
1892}
1893
1894/// Share a project into the room.
1895async fn share_project(
1896 request: proto::ShareProject,
1897 response: Response<proto::ShareProject>,
1898 session: Session,
1899) -> Result<()> {
1900 let (project_id, room) = &*session
1901 .db()
1902 .await
1903 .share_project(
1904 RoomId::from_proto(request.room_id),
1905 session.connection_id,
1906 &request.worktrees,
1907 request.is_ssh_project,
1908 )
1909 .await?;
1910 response.send(proto::ShareProjectResponse {
1911 project_id: project_id.to_proto(),
1912 })?;
1913 room_updated(room, &session.peer);
1914
1915 Ok(())
1916}
1917
1918/// Unshare a project from the room.
1919async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1920 let project_id = ProjectId::from_proto(message.project_id);
1921 unshare_project_internal(project_id, session.connection_id, &session).await
1922}
1923
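/// Stop sharing a project on behalf of the given connection, notifying guests and the
/// room (if any), and deleting the project once it is no longer shared.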
1924async fn unshare_project_internal(
1925 project_id: ProjectId,
1926 connection_id: ConnectionId,
1927 session: &Session,
1928) -> Result<()> {
1929 let delete = {
1930 let room_guard = session
1931 .db()
1932 .await
1933 .unshare_project(project_id, connection_id)
1934 .await?;
1935
1936 let (delete, room, guest_connection_ids) = &*room_guard;
1937
1938 let message = proto::UnshareProject {
1939 project_id: project_id.to_proto(),
1940 };
1941
1942 broadcast(
1943 Some(connection_id),
1944 guest_connection_ids.iter().copied(),
1945 |conn_id| session.peer.send(conn_id, message.clone()),
1946 );
1947 if let Some(room) = room {
1948 room_updated(room, &session.peer);
1949 }
1950
1951 *delete
1952 };
1953
1954 if delete {
1955 let db = session.db().await;
1956 db.delete_project(project_id).await?;
1957 }
1958
1959 Ok(())
1960}
1961
1962/// Join someone else's shared project.
1963async fn join_project(
1964 request: proto::JoinProject,
1965 response: Response<proto::JoinProject>,
1966 session: Session,
1967) -> Result<()> {
1968 let project_id = ProjectId::from_proto(request.project_id);
1969
1970 tracing::info!(%project_id, "join project");
1971
1972 let db = session.db().await;
1973 let (project, replica_id) = &mut *db
1974 .join_project(
1975 project_id,
1976 session.connection_id,
1977 session.user_id(),
1978 request.committer_name.clone(),
1979 request.committer_email.clone(),
1980 )
1981 .await?;
1982 drop(db);
1983 tracing::info!(%project_id, "join remote project");
1984 let collaborators = project
1985 .collaborators
1986 .iter()
1987 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1988 .map(|collaborator| collaborator.to_proto())
1989 .collect::<Vec<_>>();
1990 let project_id = project.id;
1991 let guest_user_id = session.user_id();
1992
1993 let worktrees = project
1994 .worktrees
1995 .iter()
1996 .map(|(id, worktree)| proto::WorktreeMetadata {
1997 id: *id,
1998 root_name: worktree.root_name.clone(),
1999 visible: worktree.visible,
2000 abs_path: worktree.abs_path.clone(),
2001 })
2002 .collect::<Vec<_>>();
2003
2004 let add_project_collaborator = proto::AddProjectCollaborator {
2005 project_id: project_id.to_proto(),
2006 collaborator: Some(proto::Collaborator {
2007 peer_id: Some(session.connection_id.into()),
2008 replica_id: replica_id.0 as u32,
2009 user_id: guest_user_id.to_proto(),
2010 is_host: false,
2011 committer_name: request.committer_name.clone(),
2012 committer_email: request.committer_email.clone(),
2013 }),
2014 };
2015
2016 for collaborator in &collaborators {
2017 session
2018 .peer
2019 .send(
2020 collaborator.peer_id.unwrap().into(),
2021 add_project_collaborator.clone(),
2022 )
2023 .trace_err();
2024 }
2025
2026    // First, send the project metadata, including each worktree's metadata, in the response.
2027 let (language_servers, language_server_capabilities) = project
2028 .language_servers
2029 .clone()
2030 .into_iter()
2031 .map(|server| (server.server, server.capabilities))
2032 .unzip();
2033 response.send(proto::JoinProjectResponse {
2034 project_id: project.id.0 as u64,
2035 worktrees: worktrees.clone(),
2036 replica_id: replica_id.0 as u32,
2037 collaborators: collaborators.clone(),
2038 language_servers,
2039 language_server_capabilities,
2040 role: project.role.into(),
2041 })?;
2042
2043 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2044 // Stream this worktree's entries.
2045 let message = proto::UpdateWorktree {
2046 project_id: project_id.to_proto(),
2047 worktree_id,
2048 abs_path: worktree.abs_path.clone(),
2049 root_name: worktree.root_name,
2050 updated_entries: worktree.entries,
2051 removed_entries: Default::default(),
2052 scan_id: worktree.scan_id,
2053 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2054 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2055 removed_repositories: Default::default(),
2056 };
2057 for update in proto::split_worktree_update(message) {
2058 session.peer.send(session.connection_id, update.clone())?;
2059 }
2060
2061 // Stream this worktree's diagnostics.
2062 for summary in worktree.diagnostic_summaries {
2063 session.peer.send(
2064 session.connection_id,
2065 proto::UpdateDiagnosticSummary {
2066 project_id: project_id.to_proto(),
2067 worktree_id: worktree.id,
2068 summary: Some(summary),
2069 },
2070 )?;
2071 }
2072
2073 for settings_file in worktree.settings_files {
2074 session.peer.send(
2075 session.connection_id,
2076 proto::UpdateWorktreeSettings {
2077 project_id: project_id.to_proto(),
2078 worktree_id: worktree.id,
2079 path: settings_file.path,
2080 content: Some(settings_file.content),
2081 kind: Some(settings_file.kind.to_proto() as i32),
2082 },
2083 )?;
2084 }
2085 }
2086
2087 for repository in mem::take(&mut project.repositories) {
2088 for update in split_repository_update(repository) {
2089 session.peer.send(session.connection_id, update)?;
2090 }
2091 }
2092
2093 for language_server in &project.language_servers {
2094 session.peer.send(
2095 session.connection_id,
2096 proto::UpdateLanguageServer {
2097 project_id: project_id.to_proto(),
2098 server_name: Some(language_server.server.name.clone()),
2099 language_server_id: language_server.server.id,
2100 variant: Some(
2101 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2102 proto::LspDiskBasedDiagnosticsUpdated {},
2103 ),
2104 ),
2105 },
2106 )?;
2107 }
2108
2109 Ok(())
2110}
2111
2112/// Leave someone else's shared project.
2113async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2114 let sender_id = session.connection_id;
2115 let project_id = ProjectId::from_proto(request.project_id);
2116 let db = session.db().await;
2117
2118 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2119 tracing::info!(
2120 %project_id,
2121 "leave project"
2122 );
2123
2124 project_left(project, &session);
2125 if let Some(room) = room {
2126 room_updated(room, &session.peer);
2127 }
2128
2129 Ok(())
2130}
2131
2132/// Updates other participants with changes to the project
2133async fn update_project(
2134 request: proto::UpdateProject,
2135 response: Response<proto::UpdateProject>,
2136 session: Session,
2137) -> Result<()> {
2138 let project_id = ProjectId::from_proto(request.project_id);
2139 let (room, guest_connection_ids) = &*session
2140 .db()
2141 .await
2142 .update_project(project_id, session.connection_id, &request.worktrees)
2143 .await?;
2144 broadcast(
2145 Some(session.connection_id),
2146 guest_connection_ids.iter().copied(),
2147 |connection_id| {
2148 session
2149 .peer
2150 .forward_send(session.connection_id, connection_id, request.clone())
2151 },
2152 );
2153 if let Some(room) = room {
2154 room_updated(room, &session.peer);
2155 }
2156 response.send(proto::Ack {})?;
2157
2158 Ok(())
2159}
2160
2161/// Updates other participants with changes to the worktree
2162async fn update_worktree(
2163 request: proto::UpdateWorktree,
2164 response: Response<proto::UpdateWorktree>,
2165 session: Session,
2166) -> Result<()> {
2167 let guest_connection_ids = session
2168 .db()
2169 .await
2170 .update_worktree(&request, session.connection_id)
2171 .await?;
2172
2173 broadcast(
2174 Some(session.connection_id),
2175 guest_connection_ids.iter().copied(),
2176 |connection_id| {
2177 session
2178 .peer
2179 .forward_send(session.connection_id, connection_id, request.clone())
2180 },
2181 );
2182 response.send(proto::Ack {})?;
2183 Ok(())
2184}
2185
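/// Updates other participants with changes to a repository.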
2186async fn update_repository(
2187 request: proto::UpdateRepository,
2188 response: Response<proto::UpdateRepository>,
2189 session: Session,
2190) -> Result<()> {
2191 let guest_connection_ids = session
2192 .db()
2193 .await
2194 .update_repository(&request, session.connection_id)
2195 .await?;
2196
2197 broadcast(
2198 Some(session.connection_id),
2199 guest_connection_ids.iter().copied(),
2200 |connection_id| {
2201 session
2202 .peer
2203 .forward_send(session.connection_id, connection_id, request.clone())
2204 },
2205 );
2206 response.send(proto::Ack {})?;
2207 Ok(())
2208}
2209
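/// Notifies other participants that a repository has been removed.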
2210async fn remove_repository(
2211 request: proto::RemoveRepository,
2212 response: Response<proto::RemoveRepository>,
2213 session: Session,
2214) -> Result<()> {
2215 let guest_connection_ids = session
2216 .db()
2217 .await
2218 .remove_repository(&request, session.connection_id)
2219 .await?;
2220
2221 broadcast(
2222 Some(session.connection_id),
2223 guest_connection_ids.iter().copied(),
2224 |connection_id| {
2225 session
2226 .peer
2227 .forward_send(session.connection_id, connection_id, request.clone())
2228 },
2229 );
2230 response.send(proto::Ack {})?;
2231 Ok(())
2232}
2233
2234/// Updates other participants with changes to the diagnostics
2235async fn update_diagnostic_summary(
2236 message: proto::UpdateDiagnosticSummary,
2237 session: Session,
2238) -> Result<()> {
2239 let guest_connection_ids = session
2240 .db()
2241 .await
2242 .update_diagnostic_summary(&message, session.connection_id)
2243 .await?;
2244
2245 broadcast(
2246 Some(session.connection_id),
2247 guest_connection_ids.iter().copied(),
2248 |connection_id| {
2249 session
2250 .peer
2251 .forward_send(session.connection_id, connection_id, message.clone())
2252 },
2253 );
2254
2255 Ok(())
2256}
2257
2258/// Updates other participants with changes to the worktree settings
2259async fn update_worktree_settings(
2260 message: proto::UpdateWorktreeSettings,
2261 session: Session,
2262) -> Result<()> {
2263 let guest_connection_ids = session
2264 .db()
2265 .await
2266 .update_worktree_settings(&message, session.connection_id)
2267 .await?;
2268
2269 broadcast(
2270 Some(session.connection_id),
2271 guest_connection_ids.iter().copied(),
2272 |connection_id| {
2273 session
2274 .peer
2275 .forward_send(session.connection_id, connection_id, message.clone())
2276 },
2277 );
2278
2279 Ok(())
2280}
2281
2282/// Notify other participants that a language server has started.
2283async fn start_language_server(
2284 request: proto::StartLanguageServer,
2285 session: Session,
2286) -> Result<()> {
2287 let guest_connection_ids = session
2288 .db()
2289 .await
2290 .start_language_server(&request, session.connection_id)
2291 .await?;
2292
2293 broadcast(
2294 Some(session.connection_id),
2295 guest_connection_ids.iter().copied(),
2296 |connection_id| {
2297 session
2298 .peer
2299 .forward_send(session.connection_id, connection_id, request.clone())
2300 },
2301 );
2302 Ok(())
2303}
2304
2305/// Notify other participants that a language server has changed.
2306async fn update_language_server(
2307 request: proto::UpdateLanguageServer,
2308 session: Session,
2309) -> Result<()> {
2310 let project_id = ProjectId::from_proto(request.project_id);
2311 let db = session.db().await;
2312
2313 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2314 {
2315 if let Some(capabilities) = update.capabilities.clone() {
2316 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2317 .await?;
2318 }
2319 }
2320
2321 let project_connection_ids = db
2322 .project_connection_ids(project_id, session.connection_id, true)
2323 .await?;
2324 broadcast(
2325 Some(session.connection_id),
2326 project_connection_ids.iter().copied(),
2327 |connection_id| {
2328 session
2329 .peer
2330 .forward_send(session.connection_id, connection_id, request.clone())
2331 },
2332 );
2333 Ok(())
2334}
2335
2336/// Forward a project request to the host. These requests should be read-only,
2337/// as guests are allowed to send them.
2338async fn forward_read_only_project_request<T>(
2339 request: T,
2340 response: Response<T>,
2341 session: Session,
2342) -> Result<()>
2343where
2344 T: EntityMessage + RequestMessage,
2345{
2346 let project_id = ProjectId::from_proto(request.remote_entity_id());
2347 let host_connection_id = session
2348 .db()
2349 .await
2350 .host_for_read_only_project_request(project_id, session.connection_id)
2351 .await?;
2352 let payload = session
2353 .peer
2354 .forward_request(session.connection_id, host_connection_id, request)
2355 .await?;
2356 response.send(payload)?;
2357 Ok(())
2358}
2359
2360/// Forward a project request to the host. These requests are disallowed
2361/// for guests.
2362async fn forward_mutating_project_request<T>(
2363 request: T,
2364 response: Response<T>,
2365 session: Session,
2366) -> Result<()>
2367where
2368 T: EntityMessage + RequestMessage,
2369{
2370 let project_id = ProjectId::from_proto(request.remote_entity_id());
2371
2372 let host_connection_id = session
2373 .db()
2374 .await
2375 .host_for_mutating_project_request(project_id, session.connection_id)
2376 .await?;
2377 let payload = session
2378 .peer
2379 .forward_request(session.connection_id, host_connection_id, request)
2380 .await?;
2381 response.send(payload)?;
2382 Ok(())
2383}
2384
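/// Forward a multi-language-server query to the project host, recording the query kind
/// in the current tracing span.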
2385async fn multi_lsp_query(
2386 request: MultiLspQuery,
2387 response: Response<MultiLspQuery>,
2388 session: Session,
2389) -> Result<()> {
2390 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2391 tracing::info!("multi_lsp_query message received");
2392 forward_mutating_project_request(request, response, session).await
2393}
2394
2395/// Forward a newly created buffer to the specified peer, after verifying that the sender hosts the project.
2396async fn create_buffer_for_peer(
2397 request: proto::CreateBufferForPeer,
2398 session: Session,
2399) -> Result<()> {
2400 session
2401 .db()
2402 .await
2403 .check_user_is_project_host(
2404 ProjectId::from_proto(request.project_id),
2405 session.connection_id,
2406 )
2407 .await?;
2408 let peer_id = request.peer_id.context("invalid peer id")?;
2409 session
2410 .peer
2411 .forward_send(session.connection_id, peer_id.into(), request)?;
2412 Ok(())
2413}
2414
2415/// Notify other participants that a buffer has been updated. This is
2416/// allowed for guests as long as the update is limited to selections.
2417async fn update_buffer(
2418 request: proto::UpdateBuffer,
2419 response: Response<proto::UpdateBuffer>,
2420 session: Session,
2421) -> Result<()> {
2422 let project_id = ProjectId::from_proto(request.project_id);
2423 let mut capability = Capability::ReadOnly;
2424
2425 for op in request.operations.iter() {
2426 match op.variant {
2427 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2428 Some(_) => capability = Capability::ReadWrite,
2429 }
2430 }
2431
2432 let host = {
2433 let guard = session
2434 .db()
2435 .await
2436 .connections_for_buffer_update(project_id, session.connection_id, capability)
2437 .await?;
2438
2439 let (host, guests) = &*guard;
2440
2441 broadcast(
2442 Some(session.connection_id),
2443 guests.clone(),
2444 |connection_id| {
2445 session
2446 .peer
2447 .forward_send(session.connection_id, connection_id, request.clone())
2448 },
2449 );
2450
2451 *host
2452 };
2453
2454 if host != session.connection_id {
2455 session
2456 .peer
2457 .forward_request(session.connection_id, host, request.clone())
2458 .await?;
2459 }
2460
2461 response.send(proto::Ack {})?;
2462 Ok(())
2463}
2464
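/// Broadcast a context update to the project host and guests. Operations other than
/// selection updates require write access to the project.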
2465async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2466 let project_id = ProjectId::from_proto(message.project_id);
2467
2468 let operation = message.operation.as_ref().context("invalid operation")?;
2469 let capability = match operation.variant.as_ref() {
2470 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2471 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2472 match buffer_op.variant {
2473 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2474 Capability::ReadOnly
2475 }
2476 _ => Capability::ReadWrite,
2477 }
2478 } else {
2479 Capability::ReadWrite
2480 }
2481 }
2482 Some(_) => Capability::ReadWrite,
2483 None => Capability::ReadOnly,
2484 };
2485
2486 let guard = session
2487 .db()
2488 .await
2489 .connections_for_buffer_update(project_id, session.connection_id, capability)
2490 .await?;
2491
2492 let (host, guests) = &*guard;
2493
2494 broadcast(
2495 Some(session.connection_id),
2496 guests.iter().chain([host]).copied(),
2497 |connection_id| {
2498 session
2499 .peer
2500 .forward_send(session.connection_id, connection_id, message.clone())
2501 },
2502 );
2503
2504 Ok(())
2505}
2506
2507/// Notify other participants that a project has been updated.
2508async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2509 request: T,
2510 session: Session,
2511) -> Result<()> {
2512 let project_id = ProjectId::from_proto(request.remote_entity_id());
2513 let project_connection_ids = session
2514 .db()
2515 .await
2516 .project_connection_ids(project_id, session.connection_id, false)
2517 .await?;
2518
2519 broadcast(
2520 Some(session.connection_id),
2521 project_connection_ids.iter().copied(),
2522 |connection_id| {
2523 session
2524 .peer
2525 .forward_send(session.connection_id, connection_id, request.clone())
2526 },
2527 );
2528 Ok(())
2529}
2530
2531/// Start following another user in a call.
2532async fn follow(
2533 request: proto::Follow,
2534 response: Response<proto::Follow>,
2535 session: Session,
2536) -> Result<()> {
2537 let room_id = RoomId::from_proto(request.room_id);
2538 let project_id = request.project_id.map(ProjectId::from_proto);
2539 let leader_id = request.leader_id.context("invalid leader id")?.into();
2540 let follower_id = session.connection_id;
2541
2542 session
2543 .db()
2544 .await
2545 .check_room_participants(room_id, leader_id, session.connection_id)
2546 .await?;
2547
2548 let response_payload = session
2549 .peer
2550 .forward_request(session.connection_id, leader_id, request)
2551 .await?;
2552 response.send(response_payload)?;
2553
2554 if let Some(project_id) = project_id {
2555 let room = session
2556 .db()
2557 .await
2558 .follow(room_id, project_id, leader_id, follower_id)
2559 .await?;
2560 room_updated(&room, &session.peer);
2561 }
2562
2563 Ok(())
2564}
2565
2566/// Stop following another user in a call.
2567async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2568 let room_id = RoomId::from_proto(request.room_id);
2569 let project_id = request.project_id.map(ProjectId::from_proto);
2570 let leader_id = request.leader_id.context("invalid leader id")?.into();
2571 let follower_id = session.connection_id;
2572
2573 session
2574 .db()
2575 .await
2576 .check_room_participants(room_id, leader_id, session.connection_id)
2577 .await?;
2578
2579 session
2580 .peer
2581 .forward_send(session.connection_id, leader_id, request)?;
2582
2583 if let Some(project_id) = project_id {
2584 let room = session
2585 .db()
2586 .await
2587 .unfollow(room_id, project_id, leader_id, follower_id)
2588 .await?;
2589 room_updated(&room, &session.peer);
2590 }
2591
2592 Ok(())
2593}
2594
2595/// Notify everyone following you of your current location.
2596async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2597 let room_id = RoomId::from_proto(request.room_id);
2598 let database = session.db.lock().await;
2599
2600 let connection_ids = if let Some(project_id) = request.project_id {
2601 let project_id = ProjectId::from_proto(project_id);
2602 database
2603 .project_connection_ids(project_id, session.connection_id, true)
2604 .await?
2605 } else {
2606 database
2607 .room_connection_ids(room_id, session.connection_id)
2608 .await?
2609 };
2610
2611 // For now, don't send view update messages back to that view's current leader.
2612 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2613 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2614 _ => None,
2615 });
2616
2617 for connection_id in connection_ids.iter().cloned() {
2618 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2619 session
2620 .peer
2621 .forward_send(session.connection_id, connection_id, request.clone())?;
2622 }
2623 }
2624 Ok(())
2625}
2626
2627/// Get public data about users.
2628async fn get_users(
2629 request: proto::GetUsers,
2630 response: Response<proto::GetUsers>,
2631 session: Session,
2632) -> Result<()> {
2633 let user_ids = request
2634 .user_ids
2635 .into_iter()
2636 .map(UserId::from_proto)
2637 .collect();
2638 let users = session
2639 .db()
2640 .await
2641 .get_users_by_ids(user_ids)
2642 .await?
2643 .into_iter()
2644 .map(|user| proto::User {
2645 id: user.id.to_proto(),
2646 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2647 github_login: user.github_login,
2648 name: user.name,
2649 })
2650 .collect();
2651 response.send(proto::UsersResponse { users })?;
2652 Ok(())
2653}
2654
2655/// Search for users (to invite) by GitHub login.
2656async fn fuzzy_search_users(
2657 request: proto::FuzzySearchUsers,
2658 response: Response<proto::FuzzySearchUsers>,
2659 session: Session,
2660) -> Result<()> {
2661 let query = request.query;
2662 let users = match query.len() {
2663 0 => vec![],
2664 1 | 2 => session
2665 .db()
2666 .await
2667 .get_user_by_github_login(&query)
2668 .await?
2669 .into_iter()
2670 .collect(),
2671 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2672 };
2673 let users = users
2674 .into_iter()
2675 .filter(|user| user.id != session.user_id())
2676 .map(|user| proto::User {
2677 id: user.id.to_proto(),
2678 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2679 github_login: user.github_login,
2680 name: user.name,
2681 })
2682 .collect();
2683 response.send(proto::UsersResponse { users })?;
2684 Ok(())
2685}
2686
2687/// Send a contact request to another user.
2688async fn request_contact(
2689 request: proto::RequestContact,
2690 response: Response<proto::RequestContact>,
2691 session: Session,
2692) -> Result<()> {
2693 let requester_id = session.user_id();
2694 let responder_id = UserId::from_proto(request.responder_id);
2695 if requester_id == responder_id {
2696 return Err(anyhow!("cannot add yourself as a contact"))?;
2697 }
2698
2699 let notifications = session
2700 .db()
2701 .await
2702 .send_contact_request(requester_id, responder_id)
2703 .await?;
2704
2705 // Update outgoing contact requests of requester
2706 let mut update = proto::UpdateContacts::default();
2707 update.outgoing_requests.push(responder_id.to_proto());
2708 for connection_id in session
2709 .connection_pool()
2710 .await
2711 .user_connection_ids(requester_id)
2712 {
2713 session.peer.send(connection_id, update.clone())?;
2714 }
2715
2716 // Update incoming contact requests of responder
2717 let mut update = proto::UpdateContacts::default();
2718 update
2719 .incoming_requests
2720 .push(proto::IncomingContactRequest {
2721 requester_id: requester_id.to_proto(),
2722 });
2723 let connection_pool = session.connection_pool().await;
2724 for connection_id in connection_pool.user_connection_ids(responder_id) {
2725 session.peer.send(connection_id, update.clone())?;
2726 }
2727
2728 send_notifications(&connection_pool, &session.peer, notifications);
2729
2730 response.send(proto::Ack {})?;
2731 Ok(())
2732}
2733
2734/// Accept or decline a contact request
2735async fn respond_to_contact_request(
2736 request: proto::RespondToContactRequest,
2737 response: Response<proto::RespondToContactRequest>,
2738 session: Session,
2739) -> Result<()> {
2740 let responder_id = session.user_id();
2741 let requester_id = UserId::from_proto(request.requester_id);
2742 let db = session.db().await;
2743 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2744 db.dismiss_contact_notification(responder_id, requester_id)
2745 .await?;
2746 } else {
2747 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2748
2749 let notifications = db
2750 .respond_to_contact_request(responder_id, requester_id, accept)
2751 .await?;
2752 let requester_busy = db.is_user_busy(requester_id).await?;
2753 let responder_busy = db.is_user_busy(responder_id).await?;
2754
2755 let pool = session.connection_pool().await;
2756 // Update responder with new contact
2757 let mut update = proto::UpdateContacts::default();
2758 if accept {
2759 update
2760 .contacts
2761 .push(contact_for_user(requester_id, requester_busy, &pool));
2762 }
2763 update
2764 .remove_incoming_requests
2765 .push(requester_id.to_proto());
2766 for connection_id in pool.user_connection_ids(responder_id) {
2767 session.peer.send(connection_id, update.clone())?;
2768 }
2769
2770 // Update requester with new contact
2771 let mut update = proto::UpdateContacts::default();
2772 if accept {
2773 update
2774 .contacts
2775 .push(contact_for_user(responder_id, responder_busy, &pool));
2776 }
2777 update
2778 .remove_outgoing_requests
2779 .push(responder_id.to_proto());
2780
2781 for connection_id in pool.user_connection_ids(requester_id) {
2782 session.peer.send(connection_id, update.clone())?;
2783 }
2784
2785 send_notifications(&pool, &session.peer, notifications);
2786 }
2787
2788 response.send(proto::Ack {})?;
2789 Ok(())
2790}
2791
2792/// Remove a contact.
2793async fn remove_contact(
2794 request: proto::RemoveContact,
2795 response: Response<proto::RemoveContact>,
2796 session: Session,
2797) -> Result<()> {
2798 let requester_id = session.user_id();
2799 let responder_id = UserId::from_proto(request.user_id);
2800 let db = session.db().await;
2801 let (contact_accepted, deleted_notification_id) =
2802 db.remove_contact(requester_id, responder_id).await?;
2803
2804 let pool = session.connection_pool().await;
2805 // Update outgoing contact requests of requester
2806 let mut update = proto::UpdateContacts::default();
2807 if contact_accepted {
2808 update.remove_contacts.push(responder_id.to_proto());
2809 } else {
2810 update
2811 .remove_outgoing_requests
2812 .push(responder_id.to_proto());
2813 }
2814 for connection_id in pool.user_connection_ids(requester_id) {
2815 session.peer.send(connection_id, update.clone())?;
2816 }
2817
2818 // Update incoming contact requests of responder
2819 let mut update = proto::UpdateContacts::default();
2820 if contact_accepted {
2821 update.remove_contacts.push(requester_id.to_proto());
2822 } else {
2823 update
2824 .remove_incoming_requests
2825 .push(requester_id.to_proto());
2826 }
2827 for connection_id in pool.user_connection_ids(responder_id) {
2828 session.peer.send(connection_id, update.clone())?;
2829 if let Some(notification_id) = deleted_notification_id {
2830 session.peer.send(
2831 connection_id,
2832 proto::DeleteNotification {
2833 notification_id: notification_id.to_proto(),
2834 },
2835 )?;
2836 }
2837 }
2838
2839 response.send(proto::Ack {})?;
2840 Ok(())
2841}
2842
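/// Returns whether clients of the given version should be automatically subscribed to
/// channels (versions prior to 0.139).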
2843fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2844 version.0.minor() < 139
2845}
2846
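/// Returns the billing plan for the given user. Staff members are always treated as
/// being on the Zed Pro plan.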
2847async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2848 if is_staff {
2849 return Ok(proto::Plan::ZedPro);
2850 }
2851
2852 let subscription = db.get_active_billing_subscription(user_id).await?;
2853 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2854
2855 let plan = if let Some(subscription_kind) = subscription_kind {
2856 match subscription_kind {
2857 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2858 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2859 SubscriptionKind::ZedFree => proto::Plan::Free,
2860 }
2861 } else {
2862 proto::Plan::Free
2863 };
2864
2865 Ok(plan)
2866}
2867
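/// Builds the `UpdateUserPlan` message for a user, including their plan, billing
/// preferences, current subscription period, and usage for that period.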
2868async fn make_update_user_plan_message(
2869 user: &User,
2870 is_staff: bool,
2871 db: &Arc<Database>,
2872 llm_db: Option<Arc<LlmDatabase>>,
2873) -> Result<proto::UpdateUserPlan> {
2874 let feature_flags = db.get_user_flags(user.id).await?;
2875 let plan = current_plan(db, user.id, is_staff).await?;
2876 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2877 let billing_preferences = db.get_billing_preferences(user.id).await?;
2878
2879 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2880 let subscription = db.get_active_billing_subscription(user.id).await?;
2881
2882 let subscription_period =
2883 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2884
2885 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2886 llm_db
2887 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2888 .await?
2889 } else {
2890 None
2891 };
2892
2893 (subscription_period, usage)
2894 } else {
2895 (None, None)
2896 };
2897
2898 let bypass_account_age_check = feature_flags
2899 .iter()
2900 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2901 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2902 && !bypass_account_age_check
2903 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2904
2905 Ok(proto::UpdateUserPlan {
2906 plan: plan.into(),
2907 trial_started_at: billing_customer
2908 .as_ref()
2909 .and_then(|billing_customer| billing_customer.trial_started_at)
2910 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2911 is_usage_based_billing_enabled: if is_staff {
2912 Some(true)
2913 } else {
2914 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2915 },
2916 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2917 proto::SubscriptionPeriod {
2918 started_at: started_at.timestamp() as u64,
2919 ended_at: ended_at.timestamp() as u64,
2920 }
2921 }),
2922 account_too_young: Some(account_too_young),
2923 has_overdue_invoices: billing_customer
2924 .map(|billing_customer| billing_customer.has_overdue_invoices),
2925 usage: Some(
2926 usage
2927 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2928 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2929 ),
2930 })
2931}
2932
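/// Returns the model request limit for the given plan, raising the limit to 1,000 for
/// trial users with the extended-trial feature flag.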
2933fn model_requests_limit(
2934 plan: cloud_llm_client::Plan,
2935 feature_flags: &Vec<String>,
2936) -> cloud_llm_client::UsageLimit {
2937 match plan.model_requests_limit() {
2938 cloud_llm_client::UsageLimit::Limited(limit) => {
2939 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2940 && feature_flags
2941 .iter()
2942 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2943 {
2944 1_000
2945 } else {
2946 limit
2947 };
2948
2949 cloud_llm_client::UsageLimit::Limited(limit)
2950 }
2951 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2952 }
2953}
2954
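/// Converts a subscription usage record into its proto representation, attaching the
/// plan's model request and edit prediction limits.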
2955fn subscription_usage_to_proto(
2956 plan: proto::Plan,
2957 usage: crate::llm::db::subscription_usage::Model,
2958 feature_flags: &Vec<String>,
2959) -> proto::SubscriptionUsage {
2960 let plan = match plan {
2961 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2962 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2963 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2964 };
2965
2966 proto::SubscriptionUsage {
2967 model_requests_usage_amount: usage.model_requests as u32,
2968 model_requests_usage_limit: Some(proto::UsageLimit {
2969 variant: Some(match model_requests_limit(plan, feature_flags) {
2970 cloud_llm_client::UsageLimit::Limited(limit) => {
2971 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2972 limit: limit as u32,
2973 })
2974 }
2975 cloud_llm_client::UsageLimit::Unlimited => {
2976 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2977 }
2978 }),
2979 }),
2980 edit_predictions_usage_amount: usage.edit_predictions as u32,
2981 edit_predictions_usage_limit: Some(proto::UsageLimit {
2982 variant: Some(match plan.edit_predictions_limit() {
2983 cloud_llm_client::UsageLimit::Limited(limit) => {
2984 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2985 limit: limit as u32,
2986 })
2987 }
2988 cloud_llm_client::UsageLimit::Unlimited => {
2989 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2990 }
2991 }),
2992 }),
2993 }
2994}
2995
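/// Builds a zeroed usage message for users with no recorded usage in the current period.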
2996fn make_default_subscription_usage(
2997 plan: proto::Plan,
2998 feature_flags: &Vec<String>,
2999) -> proto::SubscriptionUsage {
3000 let plan = match plan {
3001 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3002 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3003 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3004 };
3005
3006 proto::SubscriptionUsage {
3007 model_requests_usage_amount: 0,
3008 model_requests_usage_limit: Some(proto::UsageLimit {
3009 variant: Some(match model_requests_limit(plan, feature_flags) {
3010 cloud_llm_client::UsageLimit::Limited(limit) => {
3011 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3012 limit: limit as u32,
3013 })
3014 }
3015 cloud_llm_client::UsageLimit::Unlimited => {
3016 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3017 }
3018 }),
3019 }),
3020 edit_predictions_usage_amount: 0,
3021 edit_predictions_usage_limit: Some(proto::UsageLimit {
3022 variant: Some(match plan.edit_predictions_limit() {
3023 cloud_llm_client::UsageLimit::Limited(limit) => {
3024 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3025 limit: limit as u32,
3026 })
3027 }
3028 cloud_llm_client::UsageLimit::Unlimited => {
3029 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3030 }
3031 }),
3032 }),
3033 }
3034}
3035
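/// Sends the current user's plan and usage to their connection.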
3036async fn update_user_plan(session: &Session) -> Result<()> {
3037 let db = session.db().await;
3038
3039 let update_user_plan = make_update_user_plan_message(
3040 session.principal.user(),
3041 session.is_staff(),
3042 &db.0,
3043 session.app_state.llm_db.clone(),
3044 )
3045 .await?;
3046
3047 session
3048 .peer
3049 .send(session.connection_id, update_user_plan)
3050 .trace_err();
3051
3052 Ok(())
3053}
3054
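/// Subscribes the current user to their channels.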
3055async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3056 subscribe_user_to_channels(session.user_id(), &session).await?;
3057 Ok(())
3058}
3059
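/// Subscribes the given user to their channels in the connection pool and sends the
/// initial channel state to this session's connection.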
3060async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3061 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3062 let mut pool = session.connection_pool().await;
3063 for membership in &channels_for_user.channel_memberships {
3064 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3065 }
3066 session.peer.send(
3067 session.connection_id,
3068 build_update_user_channels(&channels_for_user),
3069 )?;
3070 session.peer.send(
3071 session.connection_id,
3072 build_channels_update(channels_for_user),
3073 )?;
3074 Ok(())
3075}
3076
3077/// Creates a new channel.
3078async fn create_channel(
3079 request: proto::CreateChannel,
3080 response: Response<proto::CreateChannel>,
3081 session: Session,
3082) -> Result<()> {
3083 let db = session.db().await;
3084
3085 let parent_id = request.parent_id.map(ChannelId::from_proto);
3086 let (channel, membership) = db
3087 .create_channel(&request.name, parent_id, session.user_id())
3088 .await?;
3089
3090 let root_id = channel.root_id();
3091 let channel = Channel::from_model(channel);
3092
3093 response.send(proto::CreateChannelResponse {
3094 channel: Some(channel.to_proto()),
3095 parent_id: request.parent_id,
3096 })?;
3097
3098 let mut connection_pool = session.connection_pool().await;
3099 if let Some(membership) = membership {
3100 connection_pool.subscribe_to_channel(
3101 membership.user_id,
3102 membership.channel_id,
3103 membership.role,
3104 );
3105 let update = proto::UpdateUserChannels {
3106 channel_memberships: vec![proto::ChannelMembership {
3107 channel_id: membership.channel_id.to_proto(),
3108 role: membership.role.into(),
3109 }],
3110 ..Default::default()
3111 };
3112 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3113 session.peer.send(connection_id, update.clone())?;
3114 }
3115 }
3116
3117 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118 if !role.can_see_channel(channel.visibility) {
3119 continue;
3120 }
3121
3122 let update = proto::UpdateChannels {
3123 channels: vec![channel.to_proto()],
3124 ..Default::default()
3125 };
3126 session.peer.send(connection_id, update.clone())?;
3127 }
3128
3129 Ok(())
3130}
3131
3132/// Delete a channel
3133async fn delete_channel(
3134 request: proto::DeleteChannel,
3135 response: Response<proto::DeleteChannel>,
3136 session: Session,
3137) -> Result<()> {
3138 let db = session.db().await;
3139
3140 let channel_id = request.channel_id;
3141 let (root_channel, removed_channels) = db
3142 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3143 .await?;
3144 response.send(proto::Ack {})?;
3145
3146 // Notify members of removed channels
3147 let mut update = proto::UpdateChannels::default();
3148 update
3149 .delete_channels
3150 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3151
3152 let connection_pool = session.connection_pool().await;
3153 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3154 session.peer.send(connection_id, update.clone())?;
3155 }
3156
3157 Ok(())
3158}
3159
3160/// Invite someone to join a channel.
3161async fn invite_channel_member(
3162 request: proto::InviteChannelMember,
3163 response: Response<proto::InviteChannelMember>,
3164 session: Session,
3165) -> Result<()> {
3166 let db = session.db().await;
3167 let channel_id = ChannelId::from_proto(request.channel_id);
3168 let invitee_id = UserId::from_proto(request.user_id);
3169 let InviteMemberResult {
3170 channel,
3171 notifications,
3172 } = db
3173 .invite_channel_member(
3174 channel_id,
3175 invitee_id,
3176 session.user_id(),
3177 request.role().into(),
3178 )
3179 .await?;
3180
3181 let update = proto::UpdateChannels {
3182 channel_invitations: vec![channel.to_proto()],
3183 ..Default::default()
3184 };
3185
3186 let connection_pool = session.connection_pool().await;
3187 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3188 session.peer.send(connection_id, update.clone())?;
3189 }
3190
3191 send_notifications(&connection_pool, &session.peer, notifications);
3192
3193 response.send(proto::Ack {})?;
3194 Ok(())
3195}
3196
3197/// Remove someone from a channel.
3198async fn remove_channel_member(
3199 request: proto::RemoveChannelMember,
3200 response: Response<proto::RemoveChannelMember>,
3201 session: Session,
3202) -> Result<()> {
3203 let db = session.db().await;
3204 let channel_id = ChannelId::from_proto(request.channel_id);
3205 let member_id = UserId::from_proto(request.user_id);
3206
3207 let RemoveChannelMemberResult {
3208 membership_update,
3209 notification_id,
3210 } = db
3211 .remove_channel_member(channel_id, member_id, session.user_id())
3212 .await?;
3213
3214 let mut connection_pool = session.connection_pool().await;
3215 notify_membership_updated(
3216 &mut connection_pool,
3217 membership_update,
3218 member_id,
3219 &session.peer,
3220 );
3221 for connection_id in connection_pool.user_connection_ids(member_id) {
3222 if let Some(notification_id) = notification_id {
3223 session
3224 .peer
3225 .send(
3226 connection_id,
3227 proto::DeleteNotification {
3228 notification_id: notification_id.to_proto(),
3229 },
3230 )
3231 .trace_err();
3232 }
3233 }
3234
3235 response.send(proto::Ack {})?;
3236 Ok(())
3237}
3238
3239/// Toggle the channel between public and private.
3240/// Care is taken to maintain the invariant that public channels only descend from public channels
3241/// (though members-only channels can appear at any point in the hierarchy).
3242async fn set_channel_visibility(
3243 request: proto::SetChannelVisibility,
3244 response: Response<proto::SetChannelVisibility>,
3245 session: Session,
3246) -> Result<()> {
3247 let db = session.db().await;
3248 let channel_id = ChannelId::from_proto(request.channel_id);
3249 let visibility = request.visibility().into();
3250
3251 let channel_model = db
3252 .set_channel_visibility(channel_id, visibility, session.user_id())
3253 .await?;
3254 let root_id = channel_model.root_id();
3255 let channel = Channel::from_model(channel_model);
3256
3257 let mut connection_pool = session.connection_pool().await;
3258 for (user_id, role) in connection_pool
3259 .channel_user_ids(root_id)
3260 .collect::<Vec<_>>()
3261 .into_iter()
3262 {
3263 let update = if role.can_see_channel(channel.visibility) {
3264 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3265 proto::UpdateChannels {
3266 channels: vec![channel.to_proto()],
3267 ..Default::default()
3268 }
3269 } else {
3270 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3271 proto::UpdateChannels {
3272 delete_channels: vec![channel.id.to_proto()],
3273 ..Default::default()
3274 }
3275 };
3276
3277 for connection_id in connection_pool.user_connection_ids(user_id) {
3278 session.peer.send(connection_id, update.clone())?;
3279 }
3280 }
3281
3282 response.send(proto::Ack {})?;
3283 Ok(())
3284}
3285
3286/// Alter the role for a user in the channel.
3287async fn set_channel_member_role(
3288 request: proto::SetChannelMemberRole,
3289 response: Response<proto::SetChannelMemberRole>,
3290 session: Session,
3291) -> Result<()> {
3292 let db = session.db().await;
3293 let channel_id = ChannelId::from_proto(request.channel_id);
3294 let member_id = UserId::from_proto(request.user_id);
3295 let result = db
3296 .set_channel_member_role(
3297 channel_id,
3298 session.user_id(),
3299 member_id,
3300 request.role().into(),
3301 )
3302 .await?;
3303
3304 match result {
3305 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3306 let mut connection_pool = session.connection_pool().await;
3307 notify_membership_updated(
3308 &mut connection_pool,
3309 membership_update,
3310 member_id,
3311 &session.peer,
3312 )
3313 }
3314 db::SetMemberRoleResult::InviteUpdated(channel) => {
3315 let update = proto::UpdateChannels {
3316 channel_invitations: vec![channel.to_proto()],
3317 ..Default::default()
3318 };
3319
3320 for connection_id in session
3321 .connection_pool()
3322 .await
3323 .user_connection_ids(member_id)
3324 {
3325 session.peer.send(connection_id, update.clone())?;
3326 }
3327 }
3328 }
3329
3330 response.send(proto::Ack {})?;
3331 Ok(())
3332}
3333
3334/// Change the name of a channel
3335async fn rename_channel(
3336 request: proto::RenameChannel,
3337 response: Response<proto::RenameChannel>,
3338 session: Session,
3339) -> Result<()> {
3340 let db = session.db().await;
3341 let channel_id = ChannelId::from_proto(request.channel_id);
3342 let channel_model = db
3343 .rename_channel(channel_id, session.user_id(), &request.name)
3344 .await?;
3345 let root_id = channel_model.root_id();
3346 let channel = Channel::from_model(channel_model);
3347
3348 response.send(proto::RenameChannelResponse {
3349 channel: Some(channel.to_proto()),
3350 })?;
3351
3352 let connection_pool = session.connection_pool().await;
3353 let update = proto::UpdateChannels {
3354 channels: vec![channel.to_proto()],
3355 ..Default::default()
3356 };
3357 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3358 if role.can_see_channel(channel.visibility) {
3359 session.peer.send(connection_id, update.clone())?;
3360 }
3361 }
3362
3363 Ok(())
3364}
3365
3366/// Move a channel to a new parent.
3367async fn move_channel(
3368 request: proto::MoveChannel,
3369 response: Response<proto::MoveChannel>,
3370 session: Session,
3371) -> Result<()> {
3372 let channel_id = ChannelId::from_proto(request.channel_id);
3373 let to = ChannelId::from_proto(request.to);
3374
3375 let (root_id, channels) = session
3376 .db()
3377 .await
3378 .move_channel(channel_id, to, session.user_id())
3379 .await?;
3380
3381 let connection_pool = session.connection_pool().await;
3382 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3383 let channels = channels
3384 .iter()
3385 .filter_map(|channel| {
3386 if role.can_see_channel(channel.visibility) {
3387 Some(channel.to_proto())
3388 } else {
3389 None
3390 }
3391 })
3392 .collect::<Vec<_>>();
3393 if channels.is_empty() {
3394 continue;
3395 }
3396
3397 let update = proto::UpdateChannels {
3398 channels,
3399 ..Default::default()
3400 };
3401
3402 session.peer.send(connection_id, update.clone())?;
3403 }
3404
3405 response.send(Ack {})?;
3406 Ok(())
3407}
3408
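/// Reorder a channel relative to its siblings, notifying members who can see the
/// affected channels.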
3409async fn reorder_channel(
3410 request: proto::ReorderChannel,
3411 response: Response<proto::ReorderChannel>,
3412 session: Session,
3413) -> Result<()> {
3414 let channel_id = ChannelId::from_proto(request.channel_id);
3415 let direction = request.direction();
3416
3417 let updated_channels = session
3418 .db()
3419 .await
3420 .reorder_channel(channel_id, direction, session.user_id())
3421 .await?;
3422
3423 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3424 let connection_pool = session.connection_pool().await;
3425 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3426 let channels = updated_channels
3427 .iter()
3428 .filter_map(|channel| {
3429 if role.can_see_channel(channel.visibility) {
3430 Some(channel.to_proto())
3431 } else {
3432 None
3433 }
3434 })
3435 .collect::<Vec<_>>();
3436
3437 if channels.is_empty() {
3438 continue;
3439 }
3440
3441 let update = proto::UpdateChannels {
3442 channels,
3443 ..Default::default()
3444 };
3445
3446 session.peer.send(connection_id, update.clone())?;
3447 }
3448 }
3449
3450 response.send(Ack {})?;
3451 Ok(())
3452}
3453
3454/// Get the list of channel members
3455async fn get_channel_members(
3456 request: proto::GetChannelMembers,
3457 response: Response<proto::GetChannelMembers>,
3458 session: Session,
3459) -> Result<()> {
3460 let db = session.db().await;
3461 let channel_id = ChannelId::from_proto(request.channel_id);
3462 let limit = if request.limit == 0 {
3463 u16::MAX as u64
3464 } else {
3465 request.limit
3466 };
3467 let (members, users) = db
3468 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3469 .await?;
3470 response.send(proto::GetChannelMembersResponse { members, users })?;
3471 Ok(())
3472}
3473
3474/// Accept or decline a channel invitation.
3475async fn respond_to_channel_invite(
3476 request: proto::RespondToChannelInvite,
3477 response: Response<proto::RespondToChannelInvite>,
3478 session: Session,
3479) -> Result<()> {
3480 let db = session.db().await;
3481 let channel_id = ChannelId::from_proto(request.channel_id);
3482 let RespondToChannelInvite {
3483 membership_update,
3484 notifications,
3485 } = db
3486 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3487 .await?;
3488
3489 let mut connection_pool = session.connection_pool().await;
3490 if let Some(membership_update) = membership_update {
3491 notify_membership_updated(
3492 &mut connection_pool,
3493 membership_update,
3494 session.user_id(),
3495 &session.peer,
3496 );
3497 } else {
3498 let update = proto::UpdateChannels {
3499 remove_channel_invitations: vec![channel_id.to_proto()],
3500 ..Default::default()
3501 };
3502
3503 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3504 session.peer.send(connection_id, update.clone())?;
3505 }
3506 };
3507
3508 send_notifications(&connection_pool, &session.peer, notifications);
3509
3510 response.send(proto::Ack {})?;
3511
3512 Ok(())
3513}
3514
3515/// Join the channel's room.
3516async fn join_channel(
3517 request: proto::JoinChannel,
3518 response: Response<proto::JoinChannel>,
3519 session: Session,
3520) -> Result<()> {
3521 let channel_id = ChannelId::from_proto(request.channel_id);
3522 join_channel_internal(channel_id, Box::new(response), session).await
3523}
3524
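/// Allows `join_channel_internal` to respond to both `JoinChannel` and `JoinRoom` requests.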
3525trait JoinChannelInternalResponse {
3526 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3527}
3528impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3529 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3530 Response::<proto::JoinChannel>::send(self, result)
3531 }
3532}
3533impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3534 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3535 Response::<proto::JoinRoom>::send(self, result)
3536 }
3537}
3538
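/// Join the room associated with a channel, cleaning up any stale room connection first
/// and issuing LiveKit credentials based on the user's role.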
3539async fn join_channel_internal(
3540 channel_id: ChannelId,
3541 response: Box<impl JoinChannelInternalResponse>,
3542 session: Session,
3543) -> Result<()> {
3544 let joined_room = {
3545 let mut db = session.db().await;
3546 // If zed quits without leaving the room, and the user re-opens zed before the
3547 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3548 // room they were in.
3549 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3550 tracing::info!(
3551 stale_connection_id = %connection,
3552 "cleaning up stale connection",
3553 );
3554 drop(db);
3555 leave_room_for_session(&session, connection).await?;
3556 db = session.db().await;
3557 }
3558
3559 let (joined_room, membership_updated, role) = db
3560 .join_channel(channel_id, session.user_id(), session.connection_id)
3561 .await?;
3562
3563 let live_kit_connection_info =
3564 session
3565 .app_state
3566 .livekit_client
3567 .as_ref()
3568 .and_then(|live_kit| {
3569 let (can_publish, token) = if role == ChannelRole::Guest {
3570 (
3571 false,
3572 live_kit
3573 .guest_token(
3574 &joined_room.room.livekit_room,
3575 &session.user_id().to_string(),
3576 )
3577 .trace_err()?,
3578 )
3579 } else {
3580 (
3581 true,
3582 live_kit
3583 .room_token(
3584 &joined_room.room.livekit_room,
3585 &session.user_id().to_string(),
3586 )
3587 .trace_err()?,
3588 )
3589 };
3590
3591 Some(LiveKitConnectionInfo {
3592 server_url: live_kit.url().into(),
3593 token,
3594 can_publish,
3595 })
3596 });
3597
3598 response.send(proto::JoinRoomResponse {
3599 room: Some(joined_room.room.clone()),
3600 channel_id: joined_room
3601 .channel
3602 .as_ref()
3603 .map(|channel| channel.id.to_proto()),
3604 live_kit_connection_info,
3605 })?;
3606
3607 let mut connection_pool = session.connection_pool().await;
3608 if let Some(membership_updated) = membership_updated {
3609 notify_membership_updated(
3610 &mut connection_pool,
3611 membership_updated,
3612 session.user_id(),
3613 &session.peer,
3614 );
3615 }
3616
3617 room_updated(&joined_room.room, &session.peer);
3618
3619 joined_room
3620 };
3621
3622 channel_updated(
3623 &joined_room.channel.context("channel not returned")?,
3624 &joined_room.room,
3625 &session.peer,
3626 &*session.connection_pool().await,
3627 );
3628
3629 update_user_contacts(session.user_id(), &session).await?;
3630 Ok(())
3631}
3632
3633/// Start editing the channel notes
3634async fn join_channel_buffer(
3635 request: proto::JoinChannelBuffer,
3636 response: Response<proto::JoinChannelBuffer>,
3637 session: Session,
3638) -> Result<()> {
3639 let db = session.db().await;
3640 let channel_id = ChannelId::from_proto(request.channel_id);
3641
3642 let open_response = db
3643 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3644 .await?;
3645
3646 let collaborators = open_response.collaborators.clone();
3647 response.send(open_response)?;
3648
3649 let update = UpdateChannelBufferCollaborators {
3650 channel_id: channel_id.to_proto(),
3651 collaborators: collaborators.clone(),
3652 };
3653 channel_buffer_updated(
3654 session.connection_id,
3655 collaborators
3656 .iter()
3657 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3658 &update,
3659 &session.peer,
3660 );
3661
3662 Ok(())
3663}
3664
3665/// Edit the channel notes
3666async fn update_channel_buffer(
3667 request: proto::UpdateChannelBuffer,
3668 session: Session,
3669) -> Result<()> {
3670 let db = session.db().await;
3671 let channel_id = ChannelId::from_proto(request.channel_id);
3672
3673 let (collaborators, epoch, version) = db
3674 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3675 .await?;
3676
3677 channel_buffer_updated(
3678 session.connection_id,
3679 collaborators.clone(),
3680 &proto::UpdateChannelBuffer {
3681 channel_id: channel_id.to_proto(),
3682 operations: request.operations,
3683 },
3684 &session.peer,
3685 );
3686
3687 let pool = &*session.connection_pool().await;
3688
3689 let non_collaborators =
3690 pool.channel_connection_ids(channel_id)
3691 .filter_map(|(connection_id, _)| {
3692 if collaborators.contains(&connection_id) {
3693 None
3694 } else {
3695 Some(connection_id)
3696 }
3697 });
3698
3699 broadcast(None, non_collaborators, |peer_id| {
3700 session.peer.send(
3701 peer_id,
3702 proto::UpdateChannels {
3703 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3704 channel_id: channel_id.to_proto(),
3705 epoch: epoch as u64,
3706 version: version.clone(),
3707 }],
3708 ..Default::default()
3709 },
3710 )
3711 });
3712
3713 Ok(())
3714}
3715
3716/// Rejoin the channel notes after a connection blip
3717async fn rejoin_channel_buffers(
3718 request: proto::RejoinChannelBuffers,
3719 response: Response<proto::RejoinChannelBuffers>,
3720 session: Session,
3721) -> Result<()> {
3722 let db = session.db().await;
3723 let buffers = db
3724 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3725 .await?;
3726
3727 for rejoined_buffer in &buffers {
3728 let collaborators_to_notify = rejoined_buffer
3729 .buffer
3730 .collaborators
3731 .iter()
3732 .filter_map(|c| Some(c.peer_id?.into()));
3733 channel_buffer_updated(
3734 session.connection_id,
3735 collaborators_to_notify,
3736 &proto::UpdateChannelBufferCollaborators {
3737 channel_id: rejoined_buffer.buffer.channel_id,
3738 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3739 },
3740 &session.peer,
3741 );
3742 }
3743
3744 response.send(proto::RejoinChannelBuffersResponse {
3745 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3746 })?;
3747
3748 Ok(())
3749}
3750
3751/// Stop editing the channel notes
3752async fn leave_channel_buffer(
3753 request: proto::LeaveChannelBuffer,
3754 response: Response<proto::LeaveChannelBuffer>,
3755 session: Session,
3756) -> Result<()> {
3757 let db = session.db().await;
3758 let channel_id = ChannelId::from_proto(request.channel_id);
3759
3760 let left_buffer = db
3761 .leave_channel_buffer(channel_id, session.connection_id)
3762 .await?;
3763
3764 response.send(Ack {})?;
3765
3766 channel_buffer_updated(
3767 session.connection_id,
3768 left_buffer.connections,
3769 &proto::UpdateChannelBufferCollaborators {
3770 channel_id: channel_id.to_proto(),
3771 collaborators: left_buffer.collaborators,
3772 },
3773 &session.peer,
3774 );
3775
3776 Ok(())
3777}
3778
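/// Broadcast a channel buffer message to the given collaborators, skipping the sender's own connection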
3779fn channel_buffer_updated<T: EnvelopedMessage>(
3780 sender_id: ConnectionId,
3781 collaborators: impl IntoIterator<Item = ConnectionId>,
3782 message: &T,
3783 peer: &Peer,
3784) {
3785 broadcast(Some(sender_id), collaborators, |peer_id| {
3786 peer.send(peer_id, message.clone())
3787 });
3788}
3789
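/// Send each notification in the batch to every active connection of its recipient, logging any delivery failures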
3790fn send_notifications(
3791 connection_pool: &ConnectionPool,
3792 peer: &Peer,
3793 notifications: db::NotificationBatch,
3794) {
3795 for (user_id, notification) in notifications {
3796 for connection_id in connection_pool.user_connection_ids(user_id) {
3797 if let Err(error) = peer.send(
3798 connection_id,
3799 proto::AddNotification {
3800 notification: Some(notification.clone()),
3801 },
3802 ) {
3803 tracing::error!(
3804 "failed to send notification to {:?} {}",
3805 connection_id,
3806 error
3807 );
3808 }
3809 }
3810 }
3811}
3812
3813/// Send a message to the channel
3814async fn send_channel_message(
3815 request: proto::SendChannelMessage,
3816 response: Response<proto::SendChannelMessage>,
3817 session: Session,
3818) -> Result<()> {
3819 // Validate the message body.
3820 let body = request.body.trim().to_string();
3821 if body.len() > MAX_MESSAGE_LEN {
3822 return Err(anyhow!("message is too long"))?;
3823 }
3824 if body.is_empty() {
3825 return Err(anyhow!("message can't be blank"))?;
3826 }
3827
3828 // TODO: adjust mentions if body is trimmed
3829
3830 let timestamp = OffsetDateTime::now_utc();
3831 let nonce = request.nonce.context("nonce can't be blank")?;
3832
3833 let channel_id = ChannelId::from_proto(request.channel_id);
3834 let CreatedChannelMessage {
3835 message_id,
3836 participant_connection_ids,
3837 notifications,
3838 } = session
3839 .db()
3840 .await
3841 .create_channel_message(
3842 channel_id,
3843 session.user_id(),
3844 &body,
3845 &request.mentions,
3846 timestamp,
3847 nonce.clone().into(),
3848 request.reply_to_message_id.map(MessageId::from_proto),
3849 )
3850 .await?;
3851
3852 let message = proto::ChannelMessage {
3853 sender_id: session.user_id().to_proto(),
3854 id: message_id.to_proto(),
3855 body,
3856 mentions: request.mentions,
3857 timestamp: timestamp.unix_timestamp() as u64,
3858 nonce: Some(nonce),
3859 reply_to_message_id: request.reply_to_message_id,
3860 edited_at: None,
3861 };
3862 broadcast(
3863 Some(session.connection_id),
3864 participant_connection_ids.clone(),
3865 |connection| {
3866 session.peer.send(
3867 connection,
3868 proto::ChannelMessageSent {
3869 channel_id: channel_id.to_proto(),
3870 message: Some(message.clone()),
3871 },
3872 )
3873 },
3874 );
3875 response.send(proto::SendChannelMessageResponse {
3876 message: Some(message),
3877 })?;
3878
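    // Let channel members who aren't in the chat know about the newest message id.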
3879 let pool = &*session.connection_pool().await;
3880 let non_participants =
3881 pool.channel_connection_ids(channel_id)
3882 .filter_map(|(connection_id, _)| {
3883 if participant_connection_ids.contains(&connection_id) {
3884 None
3885 } else {
3886 Some(connection_id)
3887 }
3888 });
3889 broadcast(None, non_participants, |peer_id| {
3890 session.peer.send(
3891 peer_id,
3892 proto::UpdateChannels {
3893 latest_channel_message_ids: vec![proto::ChannelMessageId {
3894 channel_id: channel_id.to_proto(),
3895 message_id: message_id.to_proto(),
3896 }],
3897 ..Default::default()
3898 },
3899 )
3900 });
3901 send_notifications(pool, &session.peer, notifications);
3902
3903 Ok(())
3904}
3905
3906/// Delete a channel message
3907async fn remove_channel_message(
3908 request: proto::RemoveChannelMessage,
3909 response: Response<proto::RemoveChannelMessage>,
3910 session: Session,
3911) -> Result<()> {
3912 let channel_id = ChannelId::from_proto(request.channel_id);
3913 let message_id = MessageId::from_proto(request.message_id);
3914 let (connection_ids, existing_notification_ids) = session
3915 .db()
3916 .await
3917 .remove_channel_message(channel_id, message_id, session.user_id())
3918 .await?;
3919
3920 broadcast(
3921 Some(session.connection_id),
3922 connection_ids,
3923 move |connection| {
3924 session.peer.send(connection, request.clone())?;
3925
3926 for notification_id in &existing_notification_ids {
3927 session.peer.send(
3928 connection,
3929 proto::DeleteNotification {
3930 notification_id: (*notification_id).to_proto(),
3931 },
3932 )?;
3933 }
3934
3935 Ok(())
3936 },
3937 );
3938 response.send(proto::Ack {})?;
3939 Ok(())
3940}
3941
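/// Edit a previously sent channel message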
3942async fn update_channel_message(
3943 request: proto::UpdateChannelMessage,
3944 response: Response<proto::UpdateChannelMessage>,
3945 session: Session,
3946) -> Result<()> {
3947 let channel_id = ChannelId::from_proto(request.channel_id);
3948 let message_id = MessageId::from_proto(request.message_id);
3949 let updated_at = OffsetDateTime::now_utc();
3950 let UpdatedChannelMessage {
3951 message_id,
3952 participant_connection_ids,
3953 notifications,
3954 reply_to_message_id,
3955 timestamp,
3956 deleted_mention_notification_ids,
3957 updated_mention_notifications,
3958 } = session
3959 .db()
3960 .await
3961 .update_channel_message(
3962 channel_id,
3963 message_id,
3964 session.user_id(),
3965 request.body.as_str(),
3966 &request.mentions,
3967 updated_at,
3968 )
3969 .await?;
3970
3971 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3972
3973 let message = proto::ChannelMessage {
3974 sender_id: session.user_id().to_proto(),
3975 id: message_id.to_proto(),
3976 body: request.body.clone(),
3977 mentions: request.mentions.clone(),
3978 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3979 nonce: Some(nonce),
3980 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3981 edited_at: Some(updated_at.unix_timestamp() as u64),
3982 };
3983
3984 response.send(proto::Ack {})?;
3985
3986 let pool = &*session.connection_pool().await;
3987 broadcast(
3988 Some(session.connection_id),
3989 participant_connection_ids,
3990 |connection| {
3991 session.peer.send(
3992 connection,
3993 proto::ChannelMessageUpdate {
3994 channel_id: channel_id.to_proto(),
3995 message: Some(message.clone()),
3996 },
3997 )?;
3998
3999 for notification_id in &deleted_mention_notification_ids {
4000 session.peer.send(
4001 connection,
4002 proto::DeleteNotification {
4003 notification_id: (*notification_id).to_proto(),
4004 },
4005 )?;
4006 }
4007
4008 for notification in &updated_mention_notifications {
4009 session.peer.send(
4010 connection,
4011 proto::UpdateNotification {
4012 notification: Some(notification.clone()),
4013 },
4014 )?;
4015 }
4016
4017 Ok(())
4018 },
4019 );
4020
4021 send_notifications(pool, &session.peer, notifications);
4022
4023 Ok(())
4024}
4025
4026/// Mark a channel message as read
4027async fn acknowledge_channel_message(
4028 request: proto::AckChannelMessage,
4029 session: Session,
4030) -> Result<()> {
4031 let channel_id = ChannelId::from_proto(request.channel_id);
4032 let message_id = MessageId::from_proto(request.message_id);
4033 let notifications = session
4034 .db()
4035 .await
4036 .observe_channel_message(channel_id, session.user_id(), message_id)
4037 .await?;
4038 send_notifications(
4039 &*session.connection_pool().await,
4040 &session.peer,
4041 notifications,
4042 );
4043 Ok(())
4044}
4045
4046/// Mark a buffer version as synced
4047async fn acknowledge_buffer_version(
4048 request: proto::AckBufferOperation,
4049 session: Session,
4050) -> Result<()> {
4051 let buffer_id = BufferId::from_proto(request.buffer_id);
4052 session
4053 .db()
4054 .await
4055 .observe_buffer_version(
4056 buffer_id,
4057 session.user_id(),
4058 request.epoch as i32,
4059 &request.version,
4060 )
4061 .await?;
4062 Ok(())
4063}
4064
4065/// Get a Supermaven API key for the user
4066async fn get_supermaven_api_key(
4067 _request: proto::GetSupermavenApiKey,
4068 response: Response<proto::GetSupermavenApiKey>,
4069 session: Session,
4070) -> Result<()> {
4071 let user_id: String = session.user_id().to_string();
4072 if !session.is_staff() {
4073 return Err(anyhow!("supermaven not enabled for this account"))?;
4074 }
4075
4076 let email = session.email().context("user must have an email")?;
4077
4078 let supermaven_admin_api = session
4079 .supermaven_client
4080 .as_ref()
4081 .context("supermaven not configured")?;
4082
4083 let result = supermaven_admin_api
4084 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4085 .await?;
4086
4087 response.send(proto::GetSupermavenApiKeyResponse {
4088 api_key: result.api_key,
4089 })?;
4090
4091 Ok(())
4092}
4093
4094/// Start receiving chat updates for a channel
4095async fn join_channel_chat(
4096 request: proto::JoinChannelChat,
4097 response: Response<proto::JoinChannelChat>,
4098 session: Session,
4099) -> Result<()> {
4100 let channel_id = ChannelId::from_proto(request.channel_id);
4101
4102 let db = session.db().await;
4103 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4104 .await?;
4105 let messages = db
4106 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4107 .await?;
4108 response.send(proto::JoinChannelChatResponse {
4109 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4110 messages,
4111 })?;
4112 Ok(())
4113}
4114
4115/// Stop receiving chat updates for a channel
4116async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4117 let channel_id = ChannelId::from_proto(request.channel_id);
4118 session
4119 .db()
4120 .await
4121 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4122 .await?;
4123 Ok(())
4124}
4125
4126/// Retrieve the chat history for a channel
4127async fn get_channel_messages(
4128 request: proto::GetChannelMessages,
4129 response: Response<proto::GetChannelMessages>,
4130 session: Session,
4131) -> Result<()> {
4132 let channel_id = ChannelId::from_proto(request.channel_id);
4133 let messages = session
4134 .db()
4135 .await
4136 .get_channel_messages(
4137 channel_id,
4138 session.user_id(),
4139 MESSAGE_COUNT_PER_PAGE,
4140 Some(MessageId::from_proto(request.before_message_id)),
4141 )
4142 .await?;
4143 response.send(proto::GetChannelMessagesResponse {
4144 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4145 messages,
4146 })?;
4147 Ok(())
4148}
4149
4150/// Retrieve specific chat messages
4151async fn get_channel_messages_by_id(
4152 request: proto::GetChannelMessagesById,
4153 response: Response<proto::GetChannelMessagesById>,
4154 session: Session,
4155) -> Result<()> {
4156 let message_ids = request
4157 .message_ids
4158 .iter()
4159 .map(|id| MessageId::from_proto(*id))
4160 .collect::<Vec<_>>();
4161 let messages = session
4162 .db()
4163 .await
4164 .get_channel_messages_by_id(session.user_id(), &message_ids)
4165 .await?;
4166 response.send(proto::GetChannelMessagesResponse {
4167 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4168 messages,
4169 })?;
4170 Ok(())
4171}
4172
/// Retrieve the current user's notifications
4174async fn get_notifications(
4175 request: proto::GetNotifications,
4176 response: Response<proto::GetNotifications>,
4177 session: Session,
4178) -> Result<()> {
4179 let notifications = session
4180 .db()
4181 .await
4182 .get_notifications(
4183 session.user_id(),
4184 NOTIFICATION_COUNT_PER_PAGE,
4185 request.before_id.map(db::NotificationId::from_proto),
4186 )
4187 .await?;
4188 response.send(proto::GetNotificationsResponse {
4189 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4190 notifications,
4191 })?;
4192 Ok(())
4193}
4194
/// Mark a notification as read
4196async fn mark_notification_as_read(
4197 request: proto::MarkNotificationRead,
4198 response: Response<proto::MarkNotificationRead>,
4199 session: Session,
4200) -> Result<()> {
4201 let database = &session.db().await;
4202 let notifications = database
4203 .mark_notification_as_read_by_id(
4204 session.user_id(),
4205 NotificationId::from_proto(request.notification_id),
4206 )
4207 .await?;
4208 send_notifications(
4209 &*session.connection_pool().await,
4210 &session.peer,
4211 notifications,
4212 );
4213 response.send(proto::Ack {})?;
4214 Ok(())
4215}
4216
/// Get the current user's private information
4218async fn get_private_user_info(
4219 _request: proto::GetPrivateUserInfo,
4220 response: Response<proto::GetPrivateUserInfo>,
4221 session: Session,
4222) -> Result<()> {
4223 let db = session.db().await;
4224
4225 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4226 let user = db
4227 .get_user_by_id(session.user_id())
4228 .await?
4229 .context("user not found")?;
4230 let flags = db.get_user_flags(session.user_id()).await?;
4231
4232 response.send(proto::GetPrivateUserInfoResponse {
4233 metrics_id,
4234 staff: user.admin,
4235 flags,
4236 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4237 })?;
4238 Ok(())
4239}
4240
/// Accept the terms of service (ToS) on behalf of the current user
4242async fn accept_terms_of_service(
4243 _request: proto::AcceptTermsOfService,
4244 response: Response<proto::AcceptTermsOfService>,
4245 session: Session,
4246) -> Result<()> {
4247 let db = session.db().await;
4248
4249 let accepted_tos_at = Utc::now();
4250 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4251 .await?;
4252
4253 response.send(proto::AcceptTermsOfServiceResponse {
4254 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4255 })?;
4256
4257 // When the user accepts the terms of service, we want to refresh their LLM
4258 // token to grant access.
4259 session
4260 .peer
4261 .send(session.connection_id, proto::RefreshLlmToken {})?;
4262
4263 Ok(())
4264}
4265
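/// Get an LLM API token for the current user, creating a billing customer and a free subscription if they don't exist yet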
4266async fn get_llm_api_token(
4267 _request: proto::GetLlmToken,
4268 response: Response<proto::GetLlmToken>,
4269 session: Session,
4270) -> Result<()> {
4271 let db = session.db().await;
4272
4273 let flags = db.get_user_flags(session.user_id()).await?;
4274
4275 let user_id = session.user_id();
4276 let user = db
4277 .get_user_by_id(user_id)
4278 .await?
4279 .with_context(|| format!("user {user_id} not found"))?;
4280
4281 if user.accepted_tos_at.is_none() {
4282 Err(anyhow!("terms of service not accepted"))?
4283 }
4284
4285 let stripe_client = session
4286 .app_state
4287 .stripe_client
4288 .as_ref()
4289 .context("failed to retrieve Stripe client")?;
4290
4291 let stripe_billing = session
4292 .app_state
4293 .stripe_billing
4294 .as_ref()
4295 .context("failed to retrieve Stripe billing object")?;
4296
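    // Look up the user's billing customer, creating one in Stripe if it doesn't exist yet.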
4297 let billing_customer = if let Some(billing_customer) =
4298 db.get_billing_customer_by_user_id(user.id).await?
4299 {
4300 billing_customer
4301 } else {
4302 let customer_id = stripe_billing
4303 .find_or_create_customer_by_email(user.email_address.as_deref())
4304 .await?;
4305
4306 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4307 .await?
4308 .context("billing customer not found")?
4309 };
4310
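    // Ensure the user has an active billing subscription, enrolling them in the free plan if necessary.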
4311 let billing_subscription =
4312 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4313 billing_subscription
4314 } else {
4315 let stripe_customer_id =
4316 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4317
4318 let stripe_subscription = stripe_billing
4319 .subscribe_to_zed_free(stripe_customer_id)
4320 .await?;
4321
4322 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4323 billing_customer_id: billing_customer.id,
4324 kind: Some(SubscriptionKind::ZedFree),
4325 stripe_subscription_id: stripe_subscription.id.to_string(),
4326 stripe_subscription_status: stripe_subscription.status.into(),
4327 stripe_cancellation_reason: None,
4328 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4329 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4330 })
4331 .await?
4332 };
4333
4334 let billing_preferences = db.get_billing_preferences(user.id).await?;
4335
4336 let token = LlmTokenClaims::create(
4337 &user,
4338 session.is_staff(),
4339 billing_customer,
4340 billing_preferences,
4341 &flags,
4342 billing_subscription,
4343 session.system_id.clone(),
4344 &session.app_state.config,
4345 )?;
4346 response.send(proto::GetLlmTokenResponse { token })?;
4347 Ok(())
4348}
4349
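/// Convert a tungstenite WebSocket message into its axum equivalent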
4350fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4351 let message = match message {
4352 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4353 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4354 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4355 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4356 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4357 code: frame.code.into(),
4358 reason: frame.reason.as_str().to_owned().into(),
4359 })),
4360 // We should never receive a frame while reading the message, according
4361 // to the `tungstenite` maintainers:
4362 //
4363 // > It cannot occur when you read messages from the WebSocket, but it
4364 // > can be used when you want to send the raw frames (e.g. you want to
4365 // > send the frames to the WebSocket without composing the full message first).
4366 // >
4367 // > — https://github.com/snapview/tungstenite-rs/issues/268
4368 TungsteniteMessage::Frame(_) => {
4369 bail!("received an unexpected frame while reading the message")
4370 }
4371 };
4372
4373 Ok(message)
4374}
4375
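/// Convert an axum WebSocket message into its tungstenite equivalent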
4376fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4377 match message {
4378 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4379 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4380 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4381 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4382 AxumMessage::Close(frame) => {
4383 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4384 code: frame.code.into(),
4385 reason: frame.reason.as_ref().into(),
4386 }))
4387 }
4388 }
4389}
4390
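/// Update the user's channel subscriptions in the connection pool and push the resulting channel changes to all of their connections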
4391fn notify_membership_updated(
4392 connection_pool: &mut ConnectionPool,
4393 result: MembershipUpdated,
4394 user_id: UserId,
4395 peer: &Peer,
4396) {
4397 for membership in &result.new_channels.channel_memberships {
4398 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4399 }
4400 for channel_id in &result.removed_channels {
4401 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4402 }
4403
4404 let user_channels_update = proto::UpdateUserChannels {
4405 channel_memberships: result
4406 .new_channels
4407 .channel_memberships
4408 .iter()
4409 .map(|cm| proto::ChannelMembership {
4410 channel_id: cm.channel_id.to_proto(),
4411 role: cm.role.into(),
4412 })
4413 .collect(),
4414 ..Default::default()
4415 };
4416
4417 let mut update = build_channels_update(result.new_channels);
4418 update.delete_channels = result
4419 .removed_channels
4420 .into_iter()
4421 .map(|id| id.to_proto())
4422 .collect();
4423 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4424
4425 for connection_id in connection_pool.user_connection_ids(user_id) {
4426 peer.send(connection_id, user_channels_update.clone())
4427 .trace_err();
4428 peer.send(connection_id, update.clone()).trace_err();
4429 }
4430}
4431
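/// Build an UpdateUserChannels message from the user's channel memberships and observed buffer/message state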
4432fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4433 proto::UpdateUserChannels {
4434 channel_memberships: channels
4435 .channel_memberships
4436 .iter()
4437 .map(|m| proto::ChannelMembership {
4438 channel_id: m.channel_id.to_proto(),
4439 role: m.role.into(),
4440 })
4441 .collect(),
4442 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4443 observed_channel_message_id: channels.observed_channel_messages.clone(),
4444 }
4445}
4446
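/// Build an UpdateChannels message describing the channels, participants, and invitations visible to the user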
4447fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4448 let mut update = proto::UpdateChannels::default();
4449
4450 for channel in channels.channels {
4451 update.channels.push(channel.to_proto());
4452 }
4453
4454 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4455 update.latest_channel_message_ids = channels.latest_channel_messages;
4456
4457 for (channel_id, participants) in channels.channel_participants {
4458 update
4459 .channel_participants
4460 .push(proto::ChannelParticipants {
4461 channel_id: channel_id.to_proto(),
4462 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4463 });
4464 }
4465
4466 for channel in channels.invited_channels {
4467 update.channel_invitations.push(channel.to_proto());
4468 }
4469
4470 update
4471}
4472
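/// Build the initial UpdateContacts message containing the user's accepted contacts and pending requests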
4473fn build_initial_contacts_update(
4474 contacts: Vec<db::Contact>,
4475 pool: &ConnectionPool,
4476) -> proto::UpdateContacts {
4477 let mut update = proto::UpdateContacts::default();
4478
4479 for contact in contacts {
4480 match contact {
4481 db::Contact::Accepted { user_id, busy } => {
4482 update.contacts.push(contact_for_user(user_id, busy, pool));
4483 }
4484 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4485 db::Contact::Incoming { user_id } => {
4486 update
4487 .incoming_requests
4488 .push(proto::IncomingContactRequest {
4489 requester_id: user_id.to_proto(),
4490 })
4491 }
4492 }
4493 }
4494
4495 update
4496}
4497
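/// Build a Contact proto for a user, including whether they are currently online and busy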
4498fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4499 proto::Contact {
4500 user_id: user_id.to_proto(),
4501 online: pool.is_user_online(user_id),
4502 busy,
4503 }
4504}
4505
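/// Send the latest room state to every participant in the room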
4506fn room_updated(room: &proto::Room, peer: &Peer) {
4507 broadcast(
4508 None,
4509 room.participants
4510 .iter()
4511 .filter_map(|participant| Some(participant.peer_id?.into())),
4512 |peer_id| {
4513 peer.send(
4514 peer_id,
4515 proto::RoomUpdated {
4516 room: Some(room.clone()),
4517 },
4518 )
4519 },
4520 );
4521}
4522
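/// Notify channel members who can see the channel of the room's current participants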
4523fn channel_updated(
4524 channel: &db::channel::Model,
4525 room: &proto::Room,
4526 peer: &Peer,
4527 pool: &ConnectionPool,
4528) {
4529 let participants = room
4530 .participants
4531 .iter()
4532 .map(|p| p.user_id)
4533 .collect::<Vec<_>>();
4534
4535 broadcast(
4536 None,
4537 pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
4541 }),
4542 |peer_id| {
4543 peer.send(
4544 peer_id,
4545 proto::UpdateChannels {
4546 channel_participants: vec![proto::ChannelParticipants {
4547 channel_id: channel.id.to_proto(),
4548 participant_user_ids: participants.clone(),
4549 }],
4550 ..Default::default()
4551 },
4552 )
4553 },
4554 );
4555}
4556
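/// Push the user's updated online/busy status to every connection of their accepted contacts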
4557async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4558 let db = session.db().await;
4559
4560 let contacts = db.get_contacts(user_id).await?;
4561 let busy = db.is_user_busy(user_id).await?;
4562
4563 let pool = session.connection_pool().await;
4564 let updated_contact = contact_for_user(user_id, busy, &pool);
4565 for contact in contacts {
4566 if let db::Contact::Accepted {
4567 user_id: contact_user_id,
4568 ..
4569 } = contact
4570 {
4571 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4572 session
4573 .peer
4574 .send(
4575 contact_conn_id,
4576 proto::UpdateContacts {
4577 contacts: vec![updated_contact.clone()],
4578 remove_contacts: Default::default(),
4579 incoming_requests: Default::default(),
4580 remove_incoming_requests: Default::default(),
4581 outgoing_requests: Default::default(),
4582 remove_outgoing_requests: Default::default(),
4583 },
4584 )
4585 .trace_err();
4586 }
4587 }
4588 }
4589 Ok(())
4590}
4591
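/// Remove the connection from its current room, notifying remaining participants, canceled callees, contacts, and the LiveKit server as needed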
4592async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4593 let mut contacts_to_update = HashSet::default();
4594
4595 let room_id;
4596 let canceled_calls_to_user_ids;
4597 let livekit_room;
4598 let delete_livekit_room;
4599 let room;
4600 let channel;
4601
4602 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4603 contacts_to_update.insert(session.user_id());
4604
4605 for project in left_room.left_projects.values() {
4606 project_left(project, session);
4607 }
4608
4609 room_id = RoomId::from_proto(left_room.room.id);
4610 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4611 livekit_room = mem::take(&mut left_room.room.livekit_room);
4612 delete_livekit_room = left_room.deleted;
4613 room = mem::take(&mut left_room.room);
4614 channel = mem::take(&mut left_room.channel);
4615
4616 room_updated(&room, &session.peer);
4617 } else {
4618 return Ok(());
4619 }
4620
4621 if let Some(channel) = channel {
4622 channel_updated(
4623 &channel,
4624 &room,
4625 &session.peer,
4626 &*session.connection_pool().await,
4627 );
4628 }
4629
4630 {
4631 let pool = session.connection_pool().await;
4632 for canceled_user_id in canceled_calls_to_user_ids {
4633 for connection_id in pool.user_connection_ids(canceled_user_id) {
4634 session
4635 .peer
4636 .send(
4637 connection_id,
4638 proto::CallCanceled {
4639 room_id: room_id.to_proto(),
4640 },
4641 )
4642 .trace_err();
4643 }
4644 contacts_to_update.insert(canceled_user_id);
4645 }
4646 }
4647
4648 for contact_user_id in contacts_to_update {
4649 update_user_contacts(contact_user_id, session).await?;
4650 }
4651
4652 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4653 live_kit
4654 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4655 .await
4656 .trace_err();
4657
4658 if delete_livekit_room {
4659 live_kit.delete_room(livekit_room).await.trace_err();
4660 }
4661 }
4662
4663 Ok(())
4664}
4665
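/// Leave all channel buffers for this connection and notify the remaining collaborators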
4666async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4667 let left_channel_buffers = session
4668 .db()
4669 .await
4670 .leave_channel_buffers(session.connection_id)
4671 .await?;
4672
4673 for left_buffer in left_channel_buffers {
4674 channel_buffer_updated(
4675 session.connection_id,
4676 left_buffer.connections,
4677 &proto::UpdateChannelBufferCollaborators {
4678 channel_id: left_buffer.channel_id.to_proto(),
4679 collaborators: left_buffer.collaborators,
4680 },
4681 &session.peer,
4682 );
4683 }
4684
4685 Ok(())
4686}
4687
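/// Notify a project's collaborators that this connection has left, unsharing the project if it should no longer be shared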
4688fn project_left(project: &db::LeftProject, session: &Session) {
4689 for connection_id in &project.connection_ids {
4690 if project.should_unshare {
4691 session
4692 .peer
4693 .send(
4694 *connection_id,
4695 proto::UnshareProject {
4696 project_id: project.id.to_proto(),
4697 },
4698 )
4699 .trace_err();
4700 } else {
4701 session
4702 .peer
4703 .send(
4704 *connection_id,
4705 proto::RemoveProjectCollaborator {
4706 project_id: project.id.to_proto(),
4707 peer_id: Some(session.connection_id.into()),
4708 },
4709 )
4710 .trace_err();
4711 }
4712 }
4713}
4714
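/// Extension trait for logging and discarding errors, yielding the success value as an Option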
4715pub trait ResultExt {
4716 type Ok;
4717
4718 fn trace_err(self) -> Option<Self::Ok>;
4719}
4720
4721impl<T, E> ResultExt for Result<T, E>
4722where
4723 E: std::fmt::Debug,
4724{
4725 type Ok = T;
4726
4727 #[track_caller]
4728 fn trace_err(self) -> Option<T> {
4729 match self {
4730 Ok(value) => Some(value),
4731 Err(error) => {
4732 tracing::error!("{:?}", error);
4733 None
4734 }
4735 }
4736 }
4737}