1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 AppState, Config, Error, RateLimit, Result, auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14};
15use anyhow::{Context as _, anyhow, bail};
16use async_tungstenite::tungstenite::{
17 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
18};
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use http_client::HttpClient;
37use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
38use reqwest_client::ReqwestClient;
39use rpc::proto::split_repository_update;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
45 stream::FuturesUnordered,
46};
47use prometheus::{IntGauge, register_int_gauge};
48use rpc::{
49 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 Arc, OnceLock,
67 atomic::{AtomicBool, Ordering::SeqCst},
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{MutexGuard, Semaphore, watch};
73use tower::ServiceBuilder;
74use tracing::{
75 Instrument,
76 field::{self},
77 info_span, instrument,
78};
79
/// Reconnection timeout (30 seconds).
// NOTE(review): not referenced in this chunk; presumably how long a disconnected client may
// take to reconnect before its state is released — confirm against its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up resources recorded for stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this chunk; the descriptions are
// inferred from their names — confirm against their call sites.
/// Presumably the page size for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum channel-message length (unit not visible here — confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler: takes a boxed envelope plus the sender's [`Session`] and
/// returns a boxed `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106#[derive(Clone, Debug)]
107pub enum Principal {
108 User(User),
109 Impersonated { user: User, admin: User },
110}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
/// Per-connection state that is cloned and handed to every message handler.
#[derive(Clone)]
struct Session {
    /// The effective identity of this connection (a user, or an admin impersonating one).
    principal: Principal,
    /// The peer connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquired through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The RPC peer used to send messages to clients.
    peer: Arc<Peer>,
    /// The server-wide connection pool (the same `Arc` the `Server` holds).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests this field may not be read by any handler —
    // confirm whether the attribute is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
/// The collab RPC server: accepts client connections and routes each incoming message to
/// the handler registered for its type.
pub struct Server {
    /// This server's id. It sits behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// The RPC peer through which connections are added and messages are sent.
    peer: Arc<Peer>,
    /// Connections known to this server; the same `Arc` is shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; connection loops watch it and exit.
    teardown: watch::Sender<bool>,
}
245
/// Lock guard over the server's shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this guard `!Send`. A future that must be `Send`
/// therefore cannot hold it across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` lock guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Opts the guard out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
250
/// Serializable snapshot of the server's peer and connection pool.
///
/// The connection-pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the pool it points to (via `serialize_deref`), not as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
257
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
268 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
269 let mut server = Self {
270 id: parking_lot::Mutex::new(id),
271 peer: Peer::new(id.0 as u32),
272 app_state: app_state.clone(),
273 connection_pool: Default::default(),
274 handlers: Default::default(),
275 teardown: watch::channel(false).0,
276 };
277
278 server
279 .add_request_handler(ping)
280 .add_request_handler(create_room)
281 .add_request_handler(join_room)
282 .add_request_handler(rejoin_room)
283 .add_request_handler(leave_room)
284 .add_request_handler(set_room_participant_role)
285 .add_request_handler(call)
286 .add_request_handler(cancel_call)
287 .add_message_handler(decline_call)
288 .add_request_handler(update_participant_location)
289 .add_request_handler(share_project)
290 .add_message_handler(unshare_project)
291 .add_request_handler(join_project)
292 .add_message_handler(leave_project)
293 .add_request_handler(update_project)
294 .add_request_handler(update_worktree)
295 .add_request_handler(update_repository)
296 .add_request_handler(remove_repository)
297 .add_message_handler(start_language_server)
298 .add_message_handler(update_language_server)
299 .add_message_handler(update_diagnostic_summary)
300 .add_message_handler(update_worktree_settings)
301 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
302 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
305 .add_request_handler(forward_find_search_candidates_request)
306 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
307 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
308 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
309 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
311 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
312 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
313 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
314 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
315 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
316 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
317 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
318 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
319 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
320 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
321 .add_request_handler(
322 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
323 )
324 .add_request_handler(
325 forward_read_only_project_request::<proto::LanguageServerIdForName>,
326 )
327 .add_request_handler(
328 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
329 )
330 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
331 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
332 .add_request_handler(
333 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
334 )
335 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
336 .add_request_handler(
337 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
338 )
339 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
340 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
341 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
342 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
343 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
344 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
345 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
346 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
347 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
350 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
353 )
354 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
355 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
356 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
357 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
358 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
359 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
360 .add_message_handler(create_buffer_for_peer)
361 .add_request_handler(update_buffer)
362 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
363 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
364 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
365 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
366 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
367 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
368 .add_request_handler(get_users)
369 .add_request_handler(fuzzy_search_users)
370 .add_request_handler(request_contact)
371 .add_request_handler(remove_contact)
372 .add_request_handler(respond_to_contact_request)
373 .add_message_handler(subscribe_to_channels)
374 .add_request_handler(create_channel)
375 .add_request_handler(delete_channel)
376 .add_request_handler(invite_channel_member)
377 .add_request_handler(remove_channel_member)
378 .add_request_handler(set_channel_member_role)
379 .add_request_handler(set_channel_visibility)
380 .add_request_handler(rename_channel)
381 .add_request_handler(join_channel_buffer)
382 .add_request_handler(leave_channel_buffer)
383 .add_message_handler(update_channel_buffer)
384 .add_request_handler(rejoin_channel_buffers)
385 .add_request_handler(get_channel_members)
386 .add_request_handler(respond_to_channel_invite)
387 .add_request_handler(join_channel)
388 .add_request_handler(join_channel_chat)
389 .add_message_handler(leave_channel_chat)
390 .add_request_handler(send_channel_message)
391 .add_request_handler(remove_channel_message)
392 .add_request_handler(update_channel_message)
393 .add_request_handler(get_channel_messages)
394 .add_request_handler(get_channel_messages_by_id)
395 .add_request_handler(get_notifications)
396 .add_request_handler(mark_notification_as_read)
397 .add_request_handler(move_channel)
398 .add_request_handler(follow)
399 .add_message_handler(unfollow)
400 .add_message_handler(update_followers)
401 .add_request_handler(get_private_user_info)
402 .add_request_handler(get_llm_api_token)
403 .add_request_handler(accept_terms_of_service)
404 .add_message_handler(acknowledge_channel_message)
405 .add_message_handler(acknowledge_buffer_version)
406 .add_request_handler(get_supermaven_api_key)
407 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
408 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
409 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
410 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
411 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
412 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
414 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
415 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
416 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
417 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
418 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
419 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
420 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
421 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
422 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
423 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
424 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
425 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
426 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
427 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
428 .add_message_handler(update_context)
429 .add_request_handler({
430 let app_state = app_state.clone();
431 move |request, response, session| {
432 let app_state = app_state.clone();
433 async move {
434 count_language_model_tokens(request, response, session, &app_state.config)
435 .await
436 }
437 }
438 })
439 .add_request_handler(get_cached_embeddings)
440 .add_request_handler({
441 let app_state = app_state.clone();
442 move |request, response, session| {
443 compute_embeddings(
444 request,
445 response,
446 session,
447 app_state.config.openai_api_key.clone(),
448 )
449 }
450 });
451
452 Arc::new(server)
453 }
454
455 pub async fn start(&self) -> Result<()> {
456 let server_id = *self.id.lock();
457 let app_state = self.app_state.clone();
458 let peer = self.peer.clone();
459 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
460 let pool = self.connection_pool.clone();
461 let livekit_client = self.app_state.livekit_client.clone();
462
463 let span = info_span!("start server");
464 self.app_state.executor.spawn_detached(
465 async move {
466 tracing::info!("waiting for cleanup timeout");
467 timeout.await;
468 tracing::info!("cleanup timeout expired, retrieving stale rooms");
469 if let Some((room_ids, channel_ids)) = app_state
470 .db
471 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
472 .await
473 .trace_err()
474 {
475 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
476 tracing::info!(
477 stale_channel_buffer_count = channel_ids.len(),
478 "retrieved stale channel buffers"
479 );
480
481 for channel_id in channel_ids {
482 if let Some(refreshed_channel_buffer) = app_state
483 .db
484 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
485 .await
486 .trace_err()
487 {
488 for connection_id in refreshed_channel_buffer.connection_ids {
489 peer.send(
490 connection_id,
491 proto::UpdateChannelBufferCollaborators {
492 channel_id: channel_id.to_proto(),
493 collaborators: refreshed_channel_buffer
494 .collaborators
495 .clone(),
496 },
497 )
498 .trace_err();
499 }
500 }
501 }
502
503 for room_id in room_ids {
504 let mut contacts_to_update = HashSet::default();
505 let mut canceled_calls_to_user_ids = Vec::new();
506 let mut livekit_room = String::new();
507 let mut delete_livekit_room = false;
508
509 if let Some(mut refreshed_room) = app_state
510 .db
511 .clear_stale_room_participants(room_id, server_id)
512 .await
513 .trace_err()
514 {
515 tracing::info!(
516 room_id = room_id.0,
517 new_participant_count = refreshed_room.room.participants.len(),
518 "refreshed room"
519 );
520 room_updated(&refreshed_room.room, &peer);
521 if let Some(channel) = refreshed_room.channel.as_ref() {
522 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
523 }
524 contacts_to_update
525 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
526 contacts_to_update
527 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
528 canceled_calls_to_user_ids =
529 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
530 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
531 delete_livekit_room = refreshed_room.room.participants.is_empty();
532 }
533
534 {
535 let pool = pool.lock();
536 for canceled_user_id in canceled_calls_to_user_ids {
537 for connection_id in pool.user_connection_ids(canceled_user_id) {
538 peer.send(
539 connection_id,
540 proto::CallCanceled {
541 room_id: room_id.to_proto(),
542 },
543 )
544 .trace_err();
545 }
546 }
547 }
548
549 for user_id in contacts_to_update {
550 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
551 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
552 if let Some((busy, contacts)) = busy.zip(contacts) {
553 let pool = pool.lock();
554 let updated_contact = contact_for_user(user_id, busy, &pool);
555 for contact in contacts {
556 if let db::Contact::Accepted {
557 user_id: contact_user_id,
558 ..
559 } = contact
560 {
561 for contact_conn_id in
562 pool.user_connection_ids(contact_user_id)
563 {
564 peer.send(
565 contact_conn_id,
566 proto::UpdateContacts {
567 contacts: vec![updated_contact.clone()],
568 remove_contacts: Default::default(),
569 incoming_requests: Default::default(),
570 remove_incoming_requests: Default::default(),
571 outgoing_requests: Default::default(),
572 remove_outgoing_requests: Default::default(),
573 },
574 )
575 .trace_err();
576 }
577 }
578 }
579 }
580 }
581
582 if let Some(live_kit) = livekit_client.as_ref() {
583 if delete_livekit_room {
584 live_kit.delete_room(livekit_room).await.trace_err();
585 }
586 }
587 }
588 }
589
590 app_state
591 .db
592 .delete_stale_servers(&app_state.config.zed_environment, server_id)
593 .await
594 .trace_err();
595 }
596 .instrument(span),
597 );
598 Ok(())
599 }
600
601 pub fn teardown(&self) {
602 self.peer.teardown();
603 self.connection_pool.lock().reset();
604 let _ = self.teardown.send(true);
605 }
606
607 #[cfg(test)]
608 pub fn reset(&self, id: ServerId) {
609 self.teardown();
610 *self.id.lock() = id;
611 self.peer.reset(id.0 as u32);
612 let _ = self.teardown.send(false);
613 }
614
615 #[cfg(test)]
616 pub fn id(&self) -> ServerId {
617 *self.id.lock()
618 }
619
620 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
621 where
622 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
623 Fut: 'static + Send + Future<Output = Result<()>>,
624 M: EnvelopedMessage,
625 {
626 let prev_handler = self.handlers.insert(
627 TypeId::of::<M>(),
628 Box::new(move |envelope, session| {
629 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
630 let received_at = envelope.received_at;
631 tracing::info!("message received");
632 let start_time = Instant::now();
633 let future = (handler)(*envelope, session);
634 async move {
635 let result = future.await;
636 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
637 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
638 let queue_duration_ms = total_duration_ms - processing_duration_ms;
639 let payload_type = M::NAME;
640
641 match result {
642 Err(error) => {
643 tracing::error!(
644 ?error,
645 total_duration_ms,
646 processing_duration_ms,
647 queue_duration_ms,
648 payload_type,
649 "error handling message"
650 )
651 }
652 Ok(()) => tracing::info!(
653 total_duration_ms,
654 processing_duration_ms,
655 queue_duration_ms,
656 "finished handling message"
657 ),
658 }
659 }
660 .boxed()
661 }),
662 );
663 if prev_handler.is_some() {
664 panic!("registered a handler for the same message twice");
665 }
666 self
667 }
668
669 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
670 where
671 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
672 Fut: 'static + Send + Future<Output = Result<()>>,
673 M: EnvelopedMessage,
674 {
675 self.add_handler(move |envelope, session| handler(envelope.payload, session));
676 self
677 }
678
679 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
680 where
681 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
682 Fut: Send + Future<Output = Result<()>>,
683 M: RequestMessage,
684 {
685 let handler = Arc::new(handler);
686 self.add_handler(move |envelope, session| {
687 let receipt = envelope.receipt();
688 let handler = handler.clone();
689 async move {
690 let peer = session.peer.clone();
691 let responded = Arc::new(AtomicBool::default());
692 let response = Response {
693 peer: peer.clone(),
694 responded: responded.clone(),
695 receipt,
696 };
697 match (handler)(envelope.payload, response, session).await {
698 Ok(()) => {
699 if responded.load(std::sync::atomic::Ordering::SeqCst) {
700 Ok(())
701 } else {
702 Err(anyhow!("handler did not send a response"))?
703 }
704 }
705 Err(error) => {
706 let proto_err = match &error {
707 Error::Internal(err) => err.to_proto(),
708 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
709 };
710 peer.respond_with_error(receipt, proto_err)?;
711 Err(error)
712 }
713 }
714 }
715 })
716 }
717
718 pub fn handle_connection(
719 self: &Arc<Self>,
720 connection: Connection,
721 address: String,
722 principal: Principal,
723 zed_version: ZedVersion,
724 geoip_country_code: Option<String>,
725 system_id: Option<String>,
726 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
727 executor: Executor,
728 ) -> impl Future<Output = ()> + use<> {
729 let this = self.clone();
730 let span = info_span!("handle connection", %address,
731 connection_id=field::Empty,
732 user_id=field::Empty,
733 login=field::Empty,
734 impersonator=field::Empty,
735 geoip_country_code=field::Empty
736 );
737 principal.update_span(&span);
738 if let Some(country_code) = geoip_country_code.as_ref() {
739 span.record("geoip_country_code", country_code);
740 }
741
742 let mut teardown = self.teardown.subscribe();
743 async move {
744 if *teardown.borrow() {
745 tracing::error!("server is tearing down");
746 return
747 }
748 let (connection_id, handle_io, mut incoming_rx) = this
749 .peer
750 .add_connection(connection, {
751 let executor = executor.clone();
752 move |duration| executor.sleep(duration)
753 });
754 tracing::Span::current().record("connection_id", format!("{}", connection_id));
755
756 tracing::info!("connection opened");
757
758 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
759 let http_client = match ReqwestClient::user_agent(&user_agent) {
760 Ok(http_client) => Arc::new(http_client),
761 Err(error) => {
762 tracing::error!(?error, "failed to create HTTP client");
763 return;
764 }
765 };
766
767 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
768 supermaven_admin_api_key.to_string(),
769 http_client.clone(),
770 )));
771
772 let session = Session {
773 principal: principal.clone(),
774 connection_id,
775 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
776 peer: this.peer.clone(),
777 connection_pool: this.connection_pool.clone(),
778 app_state: this.app_state.clone(),
779 http_client,
780 geoip_country_code,
781 system_id,
782 _executor: executor.clone(),
783 supermaven_client,
784 };
785
786 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
787 tracing::error!(?error, "failed to send initial client update");
788 return;
789 }
790
791 let handle_io = handle_io.fuse();
792 futures::pin_mut!(handle_io);
793
794 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
795 // This prevents deadlocks when e.g., client A performs a request to client B and
796 // client B performs a request to client A. If both clients stop processing further
797 // messages until their respective request completes, they won't have a chance to
798 // respond to the other client's request and cause a deadlock.
799 //
800 // This arrangement ensures we will attempt to process earlier messages first, but fall
801 // back to processing messages arrived later in the spirit of making progress.
802 let mut foreground_message_handlers = FuturesUnordered::new();
803 let concurrent_handlers = Arc::new(Semaphore::new(256));
804 loop {
805 let next_message = async {
806 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
807 let message = incoming_rx.next().await;
808 (permit, message)
809 }.fuse();
810 futures::pin_mut!(next_message);
811 futures::select_biased! {
812 _ = teardown.changed().fuse() => return,
813 result = handle_io => {
814 if let Err(error) = result {
815 tracing::error!(?error, "error handling I/O");
816 }
817 break;
818 }
819 _ = foreground_message_handlers.next() => {}
820 next_message = next_message => {
821 let (permit, message) = next_message;
822 if let Some(message) = message {
823 let type_name = message.payload_type_name();
824 // note: we copy all the fields from the parent span so we can query them in the logs.
825 // (https://github.com/tokio-rs/tracing/issues/2670).
826 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
827 user_id=field::Empty,
828 login=field::Empty,
829 impersonator=field::Empty,
830 );
831 principal.update_span(&span);
832 let span_enter = span.enter();
833 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
834 let is_background = message.is_background();
835 let handle_message = (handler)(message, session.clone());
836 drop(span_enter);
837
838 let handle_message = async move {
839 handle_message.await;
840 drop(permit);
841 }.instrument(span);
842 if is_background {
843 executor.spawn_detached(handle_message);
844 } else {
845 foreground_message_handlers.push(handle_message);
846 }
847 } else {
848 tracing::error!("no message handler");
849 }
850 } else {
851 tracing::info!("connection closed");
852 break;
853 }
854 }
855 }
856 }
857
858 drop(foreground_message_handlers);
859 tracing::info!("signing out");
860 if let Err(error) = connection_lost(session, teardown, executor).await {
861 tracing::error!(?error, "error signing out");
862 }
863
864 }.instrument(span)
865 }
866
867 async fn send_initial_client_update(
868 &self,
869 connection_id: ConnectionId,
870 principal: &Principal,
871 zed_version: ZedVersion,
872 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
873 session: &Session,
874 ) -> Result<()> {
875 self.peer.send(
876 connection_id,
877 proto::Hello {
878 peer_id: Some(connection_id.into()),
879 },
880 )?;
881 tracing::info!("sent hello message");
882 if let Some(send_connection_id) = send_connection_id.take() {
883 let _ = send_connection_id.send(connection_id);
884 }
885
886 match principal {
887 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
888 if !user.connected_once {
889 self.peer.send(connection_id, proto::ShowContacts {})?;
890 self.app_state
891 .db
892 .set_user_connected_once(user.id, true)
893 .await?;
894 }
895
896 update_user_plan(user.id, session).await?;
897
898 let contacts = self.app_state.db.get_contacts(user.id).await?;
899
900 {
901 let mut pool = self.connection_pool.lock();
902 pool.add_connection(connection_id, user.id, user.admin, zed_version);
903 self.peer.send(
904 connection_id,
905 build_initial_contacts_update(contacts, &pool),
906 )?;
907 }
908
909 if should_auto_subscribe_to_channels(zed_version) {
910 subscribe_user_to_channels(user.id, session).await?;
911 }
912
913 if let Some(incoming_call) =
914 self.app_state.db.incoming_call_for_user(user.id).await?
915 {
916 self.peer.send(connection_id, incoming_call)?;
917 }
918
919 update_user_contacts(user.id, session).await?;
920 }
921 }
922
923 Ok(())
924 }
925
926 pub async fn invite_code_redeemed(
927 self: &Arc<Self>,
928 inviter_id: UserId,
929 invitee_id: UserId,
930 ) -> Result<()> {
931 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
932 if let Some(code) = &user.invite_code {
933 let pool = self.connection_pool.lock();
934 let invitee_contact = contact_for_user(invitee_id, false, &pool);
935 for connection_id in pool.user_connection_ids(inviter_id) {
936 self.peer.send(
937 connection_id,
938 proto::UpdateContacts {
939 contacts: vec![invitee_contact.clone()],
940 ..Default::default()
941 },
942 )?;
943 self.peer.send(
944 connection_id,
945 proto::UpdateInviteInfo {
946 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
947 count: user.invite_count as u32,
948 },
949 )?;
950 }
951 }
952 }
953 Ok(())
954 }
955
956 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
957 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
958 if let Some(invite_code) = &user.invite_code {
959 let pool = self.connection_pool.lock();
960 for connection_id in pool.user_connection_ids(user_id) {
961 self.peer.send(
962 connection_id,
963 proto::UpdateInviteInfo {
964 url: format!(
965 "{}{}",
966 self.app_state.config.invite_link_prefix, invite_code
967 ),
968 count: user.invite_count as u32,
969 },
970 )?;
971 }
972 }
973 }
974 Ok(())
975 }
976
977 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
978 let pool = self.connection_pool.lock();
979 for connection_id in pool.user_connection_ids(user_id) {
980 self.peer
981 .send(connection_id, proto::RefreshLlmToken {})
982 .trace_err();
983 }
984 }
985
986 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
987 ServerSnapshot {
988 connection_pool: ConnectionPoolGuard {
989 guard: self.connection_pool.lock(),
990 _not_send: PhantomData,
991 },
992 peer: &self.peer,
993 }
994 }
995}
996
997impl Deref for ConnectionPoolGuard<'_> {
998 type Target = ConnectionPool;
999
1000 fn deref(&self) -> &Self::Target {
1001 &self.guard
1002 }
1003}
1004
1005impl DerefMut for ConnectionPoolGuard<'_> {
1006 fn deref_mut(&mut self) -> &mut Self::Target {
1007 &mut self.guard
1008 }
1009}
1010
1011impl Drop for ConnectionPoolGuard<'_> {
1012 fn drop(&mut self) {
1013 #[cfg(test)]
1014 self.check_invariants();
1015 }
1016}
1017
1018fn broadcast<F>(
1019 sender_id: Option<ConnectionId>,
1020 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1021 mut f: F,
1022) where
1023 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1024{
1025 for receiver_id in receiver_ids {
1026 if Some(receiver_id) != sender_id {
1027 if let Err(error) = f(receiver_id) {
1028 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1029 }
1030 }
1031 }
1032}
1033
1034pub struct ProtocolVersion(u32);
1035
1036impl Header for ProtocolVersion {
1037 fn name() -> &'static HeaderName {
1038 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1039 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1040 }
1041
1042 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1043 where
1044 Self: Sized,
1045 I: Iterator<Item = &'i axum::http::HeaderValue>,
1046 {
1047 let version = values
1048 .next()
1049 .ok_or_else(axum::headers::Error::invalid)?
1050 .to_str()
1051 .map_err(|_| axum::headers::Error::invalid())?
1052 .parse()
1053 .map_err(|_| axum::headers::Error::invalid())?;
1054 Ok(Self(version))
1055 }
1056
1057 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1058 values.extend([self.0.to_string().parse().unwrap()]);
1059 }
1060}
1061
1062pub struct AppVersionHeader(SemanticVersion);
1063impl Header for AppVersionHeader {
1064 fn name() -> &'static HeaderName {
1065 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1066 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1067 }
1068
1069 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1070 where
1071 Self: Sized,
1072 I: Iterator<Item = &'i axum::http::HeaderValue>,
1073 {
1074 let version = values
1075 .next()
1076 .ok_or_else(axum::headers::Error::invalid)?
1077 .to_str()
1078 .map_err(|_| axum::headers::Error::invalid())?
1079 .parse()
1080 .map_err(|_| axum::headers::Error::invalid())?;
1081 Ok(Self(version))
1082 }
1083
1084 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1085 values.extend([self.0.to_string().parse().unwrap()]);
1086 }
1087}
1088
/// Builds the collab server's HTTP router.
///
/// `/rpc` is registered first and wrapped by a layer stack that provides the `AppState`
/// extension and runs `auth::validate_header`. `/metrics` is registered after that layer.
/// Per axum's rule that `Router::layer` only wraps routes already added, `/metrics` is
/// therefore not behind the auth middleware. The final `Extension(server)` layer is applied
/// after both routes exist.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Order matters: this layer only applies to the routes registered above it.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1100
1101pub async fn handle_websocket_request(
1102 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1103 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1104 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1105 Extension(server): Extension<Arc<Server>>,
1106 Extension(principal): Extension<Principal>,
1107 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1108 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1109 ws: WebSocketUpgrade,
1110) -> axum::response::Response {
1111 if protocol_version != rpc::PROTOCOL_VERSION {
1112 return (
1113 StatusCode::UPGRADE_REQUIRED,
1114 "client must be upgraded".to_string(),
1115 )
1116 .into_response();
1117 }
1118
1119 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1120 return (
1121 StatusCode::UPGRADE_REQUIRED,
1122 "no version header found".to_string(),
1123 )
1124 .into_response();
1125 };
1126
1127 if !version.can_collaborate() {
1128 return (
1129 StatusCode::UPGRADE_REQUIRED,
1130 "client must be upgraded".to_string(),
1131 )
1132 .into_response();
1133 }
1134
1135 let socket_address = socket_address.to_string();
1136 ws.on_upgrade(move |socket| {
1137 let socket = socket
1138 .map_ok(to_tungstenite_message)
1139 .err_into()
1140 .with(|message| async move { to_axum_message(message) });
1141 let connection = Connection::new(Box::pin(socket));
1142 async move {
1143 server
1144 .handle_connection(
1145 connection,
1146 socket_address,
1147 principal,
1148 version,
1149 country_code_header.map(|header| header.to_string()),
1150 system_id_header.map(|header| header.to_string()),
1151 None,
1152 Executor::Production,
1153 )
1154 .await;
1155 }
1156 })
1157}
1158
1159pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1160 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1161 let connections_metric = CONNECTIONS_METRIC
1162 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1163
1164 let connections = server
1165 .connection_pool
1166 .lock()
1167 .connections()
1168 .filter(|connection| !connection.admin)
1169 .count();
1170 connections_metric.set(connections as _);
1171
1172 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1173 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1174 register_int_gauge!(
1175 "shared_projects",
1176 "number of open projects with one or more guests"
1177 )
1178 .unwrap()
1179 });
1180
1181 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1182 shared_projects_metric.set(shared_projects as _);
1183
1184 let encoder = prometheus::TextEncoder::new();
1185 let metric_families = prometheus::gather();
1186 let encoded_metrics = encoder
1187 .encode_to_string(&metric_families)
1188 .map_err(|err| anyhow!("{}", err))?;
1189 Ok(encoded_metrics)
1190}
1191
/// Cleans up after a client connection goes away.
///
/// Immediately disconnects the peer and removes the connection from the pool; a pool
/// removal failure is returned. It then reports the lost connection to the database,
/// where errors are only logged.
///
/// Next it waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapsing on `executor`, or
/// - a change on `teardown`.
///
/// `select_biased!` polls the timeout arm first when both are ready. If the timeout wins:
/// - `leave_room_for_session` and `leave_channel_buffers_for_session` run (errors logged);
/// - if the user has no remaining online connection, `decline_call(None, ..)` runs and any
///   returned room is broadcast;
/// - contacts are updated.
///
/// If `teardown` changes first, no further cleanup happens here.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client a grace period to reconnect before tearing down its resources.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // A teardown signal skips the delayed cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1238
1239/// Acknowledges a ping from a client, used to keep the connection alive.
1240async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1241 response.send(proto::Ack {})?;
1242 Ok(())
1243}
1244
1245/// Creates a new room for calling (outside of channels)
1246async fn create_room(
1247 _request: proto::CreateRoom,
1248 response: Response<proto::CreateRoom>,
1249 session: Session,
1250) -> Result<()> {
1251 let livekit_room = nanoid::nanoid!(30);
1252
1253 let live_kit_connection_info = util::maybe!(async {
1254 let live_kit = session.app_state.livekit_client.as_ref();
1255 let live_kit = live_kit?;
1256 let user_id = session.user_id().to_string();
1257
1258 let token = live_kit
1259 .room_token(&livekit_room, &user_id.to_string())
1260 .trace_err()?;
1261
1262 Some(proto::LiveKitConnectionInfo {
1263 server_url: live_kit.url().into(),
1264 token,
1265 can_publish: true,
1266 })
1267 })
1268 .await;
1269
1270 let room = session
1271 .db()
1272 .await
1273 .create_room(session.user_id(), session.connection_id, &livekit_room)
1274 .await?;
1275
1276 response.send(proto::CreateRoomResponse {
1277 room: Some(room.clone()),
1278 live_kit_connection_info,
1279 })?;
1280
1281 update_user_contacts(session.user_id(), &session).await?;
1282 Ok(())
1283}
1284
1285/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1286async fn join_room(
1287 request: proto::JoinRoom,
1288 response: Response<proto::JoinRoom>,
1289 session: Session,
1290) -> Result<()> {
1291 let room_id = RoomId::from_proto(request.id);
1292
1293 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1294
1295 if let Some(channel_id) = channel_id {
1296 return join_channel_internal(channel_id, Box::new(response), session).await;
1297 }
1298
1299 let joined_room = {
1300 let room = session
1301 .db()
1302 .await
1303 .join_room(room_id, session.user_id(), session.connection_id)
1304 .await?;
1305 room_updated(&room.room, &session.peer);
1306 room.into_inner()
1307 };
1308
1309 for connection_id in session
1310 .connection_pool()
1311 .await
1312 .user_connection_ids(session.user_id())
1313 {
1314 session
1315 .peer
1316 .send(
1317 connection_id,
1318 proto::CallCanceled {
1319 room_id: room_id.to_proto(),
1320 },
1321 )
1322 .trace_err();
1323 }
1324
1325 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1326 {
1327 live_kit
1328 .room_token(
1329 &joined_room.room.livekit_room,
1330 &session.user_id().to_string(),
1331 )
1332 .trace_err()
1333 .map(|token| proto::LiveKitConnectionInfo {
1334 server_url: live_kit.url().into(),
1335 token,
1336 can_publish: true,
1337 })
1338 } else {
1339 None
1340 };
1341
1342 response.send(proto::JoinRoomResponse {
1343 room: Some(joined_room.room),
1344 channel_id: None,
1345 live_kit_connection_info,
1346 })?;
1347
1348 update_user_contacts(session.user_id(), &session).await?;
1349 Ok(())
1350}
1351
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the database's rejoin result is held (inside the inner block), this:
/// - replies with the room plus the reshared and rejoined projects;
/// - broadcasts the room update;
/// - tells collaborators of each reshared project about this connection's new peer id and
///   forwards the current worktree list to everyone but this connection;
/// - notifies and streams state for the rejoined projects via `notify_rejoined_projects`.
///
/// After that result is released, it broadcasts a channel update when the room belongs to a
/// channel, then updates the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values moved out of the inner scope so they outlive the rejoin result.
    let room;
    let channel;
    {
        // NOTE(review): presumably a guard type (it is unwrapped with `into_inner` below); the
        // inner scope bounds how long it is held.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map this session's old peer id to the new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the worktree list to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1443
1444fn notify_rejoined_projects(
1445 rejoined_projects: &mut Vec<RejoinedProject>,
1446 session: &Session,
1447) -> Result<()> {
1448 for project in rejoined_projects.iter() {
1449 for collaborator in &project.collaborators {
1450 session
1451 .peer
1452 .send(
1453 collaborator.connection_id,
1454 proto::UpdateProjectCollaborator {
1455 project_id: project.id.to_proto(),
1456 old_peer_id: Some(project.old_connection_id.into()),
1457 new_peer_id: Some(session.connection_id.into()),
1458 },
1459 )
1460 .trace_err();
1461 }
1462 }
1463
1464 for project in rejoined_projects {
1465 for worktree in mem::take(&mut project.worktrees) {
1466 // Stream this worktree's entries.
1467 let message = proto::UpdateWorktree {
1468 project_id: project.id.to_proto(),
1469 worktree_id: worktree.id,
1470 abs_path: worktree.abs_path.clone(),
1471 root_name: worktree.root_name,
1472 updated_entries: worktree.updated_entries,
1473 removed_entries: worktree.removed_entries,
1474 scan_id: worktree.scan_id,
1475 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1476 updated_repositories: worktree.updated_repositories,
1477 removed_repositories: worktree.removed_repositories,
1478 };
1479 for update in proto::split_worktree_update(message) {
1480 session.peer.send(session.connection_id, update)?;
1481 }
1482
1483 // Stream this worktree's diagnostics.
1484 for summary in worktree.diagnostic_summaries {
1485 session.peer.send(
1486 session.connection_id,
1487 proto::UpdateDiagnosticSummary {
1488 project_id: project.id.to_proto(),
1489 worktree_id: worktree.id,
1490 summary: Some(summary),
1491 },
1492 )?;
1493 }
1494
1495 for settings_file in worktree.settings_files {
1496 session.peer.send(
1497 session.connection_id,
1498 proto::UpdateWorktreeSettings {
1499 project_id: project.id.to_proto(),
1500 worktree_id: worktree.id,
1501 path: settings_file.path,
1502 content: Some(settings_file.content),
1503 kind: Some(settings_file.kind.to_proto().into()),
1504 },
1505 )?;
1506 }
1507 }
1508
1509 for repository in mem::take(&mut project.updated_repositories) {
1510 for update in split_repository_update(repository) {
1511 session.peer.send(session.connection_id, update)?;
1512 }
1513 }
1514
1515 for id in mem::take(&mut project.removed_repositories) {
1516 session.peer.send(
1517 session.connection_id,
1518 proto::RemoveRepository {
1519 project_id: project.id.to_proto(),
1520 id,
1521 },
1522 )?;
1523 }
1524 }
1525
1526 Ok(())
1527}
1528
1529/// leave room disconnects from the room.
1530async fn leave_room(
1531 _: proto::LeaveRoom,
1532 response: Response<proto::LeaveRoom>,
1533 session: Session,
1534) -> Result<()> {
1535 leave_room_for_session(&session, session.connection_id).await?;
1536 response.send(proto::Ack {})?;
1537 Ok(())
1538}
1539
1540/// Updates the permissions of someone else in the room.
1541async fn set_room_participant_role(
1542 request: proto::SetRoomParticipantRole,
1543 response: Response<proto::SetRoomParticipantRole>,
1544 session: Session,
1545) -> Result<()> {
1546 let user_id = UserId::from_proto(request.user_id);
1547 let role = ChannelRole::from(request.role());
1548
1549 let (livekit_room, can_publish) = {
1550 let room = session
1551 .db()
1552 .await
1553 .set_room_participant_role(
1554 session.user_id(),
1555 RoomId::from_proto(request.room_id),
1556 user_id,
1557 role,
1558 )
1559 .await?;
1560
1561 let livekit_room = room.livekit_room.clone();
1562 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1563 room_updated(&room, &session.peer);
1564 (livekit_room, can_publish)
1565 };
1566
1567 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1568 live_kit
1569 .update_participant(
1570 livekit_room.clone(),
1571 request.user_id.to_string(),
1572 livekit_api::proto::ParticipantPermission {
1573 can_subscribe: true,
1574 can_publish,
1575 can_publish_data: can_publish,
1576 hidden: false,
1577 recorder: false,
1578 },
1579 )
1580 .await
1581 .trace_err();
1582 }
1583
1584 response.send(proto::Ack {})?;
1585 Ok(())
1586}
1587
1588/// Call someone else into the current room
1589async fn call(
1590 request: proto::Call,
1591 response: Response<proto::Call>,
1592 session: Session,
1593) -> Result<()> {
1594 let room_id = RoomId::from_proto(request.room_id);
1595 let calling_user_id = session.user_id();
1596 let calling_connection_id = session.connection_id;
1597 let called_user_id = UserId::from_proto(request.called_user_id);
1598 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1599 if !session
1600 .db()
1601 .await
1602 .has_contact(calling_user_id, called_user_id)
1603 .await?
1604 {
1605 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1606 }
1607
1608 let incoming_call = {
1609 let (room, incoming_call) = &mut *session
1610 .db()
1611 .await
1612 .call(
1613 room_id,
1614 calling_user_id,
1615 calling_connection_id,
1616 called_user_id,
1617 initial_project_id,
1618 )
1619 .await?;
1620 room_updated(room, &session.peer);
1621 mem::take(incoming_call)
1622 };
1623 update_user_contacts(called_user_id, &session).await?;
1624
1625 let mut calls = session
1626 .connection_pool()
1627 .await
1628 .user_connection_ids(called_user_id)
1629 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1630 .collect::<FuturesUnordered<_>>();
1631
1632 while let Some(call_response) = calls.next().await {
1633 match call_response.as_ref() {
1634 Ok(_) => {
1635 response.send(proto::Ack {})?;
1636 return Ok(());
1637 }
1638 Err(_) => {
1639 call_response.trace_err();
1640 }
1641 }
1642 }
1643
1644 {
1645 let room = session
1646 .db()
1647 .await
1648 .call_failed(room_id, called_user_id)
1649 .await?;
1650 room_updated(&room, &session.peer);
1651 }
1652 update_user_contacts(called_user_id, &session).await?;
1653
1654 Err(anyhow!("failed to ring user"))?
1655}
1656
1657/// Cancel an outgoing call.
1658async fn cancel_call(
1659 request: proto::CancelCall,
1660 response: Response<proto::CancelCall>,
1661 session: Session,
1662) -> Result<()> {
1663 let called_user_id = UserId::from_proto(request.called_user_id);
1664 let room_id = RoomId::from_proto(request.room_id);
1665 {
1666 let room = session
1667 .db()
1668 .await
1669 .cancel_call(room_id, session.connection_id, called_user_id)
1670 .await?;
1671 room_updated(&room, &session.peer);
1672 }
1673
1674 for connection_id in session
1675 .connection_pool()
1676 .await
1677 .user_connection_ids(called_user_id)
1678 {
1679 session
1680 .peer
1681 .send(
1682 connection_id,
1683 proto::CallCanceled {
1684 room_id: room_id.to_proto(),
1685 },
1686 )
1687 .trace_err();
1688 }
1689 response.send(proto::Ack {})?;
1690
1691 update_user_contacts(called_user_id, &session).await?;
1692 Ok(())
1693}
1694
1695/// Decline an incoming call.
1696async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1697 let room_id = RoomId::from_proto(message.room_id);
1698 {
1699 let room = session
1700 .db()
1701 .await
1702 .decline_call(Some(room_id), session.user_id())
1703 .await?
1704 .ok_or_else(|| anyhow!("failed to decline call"))?;
1705 room_updated(&room, &session.peer);
1706 }
1707
1708 for connection_id in session
1709 .connection_pool()
1710 .await
1711 .user_connection_ids(session.user_id())
1712 {
1713 session
1714 .peer
1715 .send(
1716 connection_id,
1717 proto::CallCanceled {
1718 room_id: room_id.to_proto(),
1719 },
1720 )
1721 .trace_err();
1722 }
1723 update_user_contacts(session.user_id(), &session).await?;
1724 Ok(())
1725}
1726
1727/// Updates other participants in the room with your current location.
1728async fn update_participant_location(
1729 request: proto::UpdateParticipantLocation,
1730 response: Response<proto::UpdateParticipantLocation>,
1731 session: Session,
1732) -> Result<()> {
1733 let room_id = RoomId::from_proto(request.room_id);
1734 let location = request
1735 .location
1736 .ok_or_else(|| anyhow!("invalid location"))?;
1737
1738 let db = session.db().await;
1739 let room = db
1740 .update_room_participant_location(room_id, session.connection_id, location)
1741 .await?;
1742
1743 room_updated(&room, &session.peer);
1744 response.send(proto::Ack {})?;
1745 Ok(())
1746}
1747
1748/// Share a project into the room.
1749async fn share_project(
1750 request: proto::ShareProject,
1751 response: Response<proto::ShareProject>,
1752 session: Session,
1753) -> Result<()> {
1754 let (project_id, room) = &*session
1755 .db()
1756 .await
1757 .share_project(
1758 RoomId::from_proto(request.room_id),
1759 session.connection_id,
1760 &request.worktrees,
1761 request.is_ssh_project,
1762 )
1763 .await?;
1764 response.send(proto::ShareProjectResponse {
1765 project_id: project_id.to_proto(),
1766 })?;
1767 room_updated(room, &session.peer);
1768
1769 Ok(())
1770}
1771
1772/// Unshare a project from the room.
1773async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1774 let project_id = ProjectId::from_proto(message.project_id);
1775 unshare_project_internal(project_id, session.connection_id, &session).await
1776}
1777
1778async fn unshare_project_internal(
1779 project_id: ProjectId,
1780 connection_id: ConnectionId,
1781 session: &Session,
1782) -> Result<()> {
1783 let delete = {
1784 let room_guard = session
1785 .db()
1786 .await
1787 .unshare_project(project_id, connection_id)
1788 .await?;
1789
1790 let (delete, room, guest_connection_ids) = &*room_guard;
1791
1792 let message = proto::UnshareProject {
1793 project_id: project_id.to_proto(),
1794 };
1795
1796 broadcast(
1797 Some(connection_id),
1798 guest_connection_ids.iter().copied(),
1799 |conn_id| session.peer.send(conn_id, message.clone()),
1800 );
1801 if let Some(room) = room {
1802 room_updated(room, &session.peer);
1803 }
1804
1805 *delete
1806 };
1807
1808 if delete {
1809 let db = session.db().await;
1810 db.delete_project(project_id).await?;
1811 }
1812
1813 Ok(())
1814}
1815
/// Join someone else's shared project.
///
/// Adds this connection to the project via the database, releases the database handle, and
/// then hands off to `join_project_internal` to notify collaborators and stream the
/// project's state to the joining client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The returned value is kept alive by the borrow below; only `db` itself is dropped next.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1834
/// Abstraction over how a `JoinProjectResponse` is delivered to the joining client.
///
/// NOTE(review): presumably exists so `join_project_internal` can serve other request types;
/// only the `Response<proto::JoinProject>` impl is visible here.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Calls `Response`'s own `send` through its fully qualified type path.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1843
1844fn join_project_internal(
1845 response: impl JoinProjectInternalResponse,
1846 session: Session,
1847 project: &mut Project,
1848 replica_id: &ReplicaId,
1849) -> Result<()> {
1850 let collaborators = project
1851 .collaborators
1852 .iter()
1853 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1854 .map(|collaborator| collaborator.to_proto())
1855 .collect::<Vec<_>>();
1856 let project_id = project.id;
1857 let guest_user_id = session.user_id();
1858
1859 let worktrees = project
1860 .worktrees
1861 .iter()
1862 .map(|(id, worktree)| proto::WorktreeMetadata {
1863 id: *id,
1864 root_name: worktree.root_name.clone(),
1865 visible: worktree.visible,
1866 abs_path: worktree.abs_path.clone(),
1867 })
1868 .collect::<Vec<_>>();
1869
1870 let add_project_collaborator = proto::AddProjectCollaborator {
1871 project_id: project_id.to_proto(),
1872 collaborator: Some(proto::Collaborator {
1873 peer_id: Some(session.connection_id.into()),
1874 replica_id: replica_id.0 as u32,
1875 user_id: guest_user_id.to_proto(),
1876 is_host: false,
1877 }),
1878 };
1879
1880 for collaborator in &collaborators {
1881 session
1882 .peer
1883 .send(
1884 collaborator.peer_id.unwrap().into(),
1885 add_project_collaborator.clone(),
1886 )
1887 .trace_err();
1888 }
1889
1890 // First, we send the metadata associated with each worktree.
1891 response.send(proto::JoinProjectResponse {
1892 project_id: project.id.0 as u64,
1893 worktrees: worktrees.clone(),
1894 replica_id: replica_id.0 as u32,
1895 collaborators: collaborators.clone(),
1896 language_servers: project.language_servers.clone(),
1897 role: project.role.into(),
1898 })?;
1899
1900 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1901 // Stream this worktree's entries.
1902 let message = proto::UpdateWorktree {
1903 project_id: project_id.to_proto(),
1904 worktree_id,
1905 abs_path: worktree.abs_path.clone(),
1906 root_name: worktree.root_name,
1907 updated_entries: worktree.entries,
1908 removed_entries: Default::default(),
1909 scan_id: worktree.scan_id,
1910 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1911 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1912 removed_repositories: Default::default(),
1913 };
1914 for update in proto::split_worktree_update(message) {
1915 session.peer.send(session.connection_id, update.clone())?;
1916 }
1917
1918 // Stream this worktree's diagnostics.
1919 for summary in worktree.diagnostic_summaries {
1920 session.peer.send(
1921 session.connection_id,
1922 proto::UpdateDiagnosticSummary {
1923 project_id: project_id.to_proto(),
1924 worktree_id: worktree.id,
1925 summary: Some(summary),
1926 },
1927 )?;
1928 }
1929
1930 for settings_file in worktree.settings_files {
1931 session.peer.send(
1932 session.connection_id,
1933 proto::UpdateWorktreeSettings {
1934 project_id: project_id.to_proto(),
1935 worktree_id: worktree.id,
1936 path: settings_file.path,
1937 content: Some(settings_file.content),
1938 kind: Some(settings_file.kind.to_proto() as i32),
1939 },
1940 )?;
1941 }
1942 }
1943
1944 for repository in mem::take(&mut project.repositories) {
1945 for update in split_repository_update(repository) {
1946 session.peer.send(session.connection_id, update)?;
1947 }
1948 }
1949
1950 for language_server in &project.language_servers {
1951 session.peer.send(
1952 session.connection_id,
1953 proto::UpdateLanguageServer {
1954 project_id: project_id.to_proto(),
1955 language_server_id: language_server.id,
1956 variant: Some(
1957 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1958 proto::LspDiskBasedDiagnosticsUpdated {},
1959 ),
1960 ),
1961 },
1962 )?;
1963 }
1964
1965 Ok(())
1966}
1967
1968/// Leave someone elses shared project.
1969async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1970 let sender_id = session.connection_id;
1971 let project_id = ProjectId::from_proto(request.project_id);
1972 let db = session.db().await;
1973
1974 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1975 tracing::info!(
1976 %project_id,
1977 "leave project"
1978 );
1979
1980 project_left(project, &session);
1981 if let Some(room) = room {
1982 room_updated(room, &session.peer);
1983 }
1984
1985 Ok(())
1986}
1987
1988/// Updates other participants with changes to the project
1989async fn update_project(
1990 request: proto::UpdateProject,
1991 response: Response<proto::UpdateProject>,
1992 session: Session,
1993) -> Result<()> {
1994 let project_id = ProjectId::from_proto(request.project_id);
1995 let (room, guest_connection_ids) = &*session
1996 .db()
1997 .await
1998 .update_project(project_id, session.connection_id, &request.worktrees)
1999 .await?;
2000 broadcast(
2001 Some(session.connection_id),
2002 guest_connection_ids.iter().copied(),
2003 |connection_id| {
2004 session
2005 .peer
2006 .forward_send(session.connection_id, connection_id, request.clone())
2007 },
2008 );
2009 if let Some(room) = room {
2010 room_updated(room, &session.peer);
2011 }
2012 response.send(proto::Ack {})?;
2013
2014 Ok(())
2015}
2016
2017/// Updates other participants with changes to the worktree
2018async fn update_worktree(
2019 request: proto::UpdateWorktree,
2020 response: Response<proto::UpdateWorktree>,
2021 session: Session,
2022) -> Result<()> {
2023 let guest_connection_ids = session
2024 .db()
2025 .await
2026 .update_worktree(&request, session.connection_id)
2027 .await?;
2028
2029 broadcast(
2030 Some(session.connection_id),
2031 guest_connection_ids.iter().copied(),
2032 |connection_id| {
2033 session
2034 .peer
2035 .forward_send(session.connection_id, connection_id, request.clone())
2036 },
2037 );
2038 response.send(proto::Ack {})?;
2039 Ok(())
2040}
2041
2042async fn update_repository(
2043 request: proto::UpdateRepository,
2044 response: Response<proto::UpdateRepository>,
2045 session: Session,
2046) -> Result<()> {
2047 let guest_connection_ids = session
2048 .db()
2049 .await
2050 .update_repository(&request, session.connection_id)
2051 .await?;
2052
2053 broadcast(
2054 Some(session.connection_id),
2055 guest_connection_ids.iter().copied(),
2056 |connection_id| {
2057 session
2058 .peer
2059 .forward_send(session.connection_id, connection_id, request.clone())
2060 },
2061 );
2062 response.send(proto::Ack {})?;
2063 Ok(())
2064}
2065
2066async fn remove_repository(
2067 request: proto::RemoveRepository,
2068 response: Response<proto::RemoveRepository>,
2069 session: Session,
2070) -> Result<()> {
2071 let guest_connection_ids = session
2072 .db()
2073 .await
2074 .remove_repository(&request, session.connection_id)
2075 .await?;
2076
2077 broadcast(
2078 Some(session.connection_id),
2079 guest_connection_ids.iter().copied(),
2080 |connection_id| {
2081 session
2082 .peer
2083 .forward_send(session.connection_id, connection_id, request.clone())
2084 },
2085 );
2086 response.send(proto::Ack {})?;
2087 Ok(())
2088}
2089
2090/// Updates other participants with changes to the diagnostics
2091async fn update_diagnostic_summary(
2092 message: proto::UpdateDiagnosticSummary,
2093 session: Session,
2094) -> Result<()> {
2095 let guest_connection_ids = session
2096 .db()
2097 .await
2098 .update_diagnostic_summary(&message, session.connection_id)
2099 .await?;
2100
2101 broadcast(
2102 Some(session.connection_id),
2103 guest_connection_ids.iter().copied(),
2104 |connection_id| {
2105 session
2106 .peer
2107 .forward_send(session.connection_id, connection_id, message.clone())
2108 },
2109 );
2110
2111 Ok(())
2112}
2113
2114/// Updates other participants with changes to the worktree settings
2115async fn update_worktree_settings(
2116 message: proto::UpdateWorktreeSettings,
2117 session: Session,
2118) -> Result<()> {
2119 let guest_connection_ids = session
2120 .db()
2121 .await
2122 .update_worktree_settings(&message, session.connection_id)
2123 .await?;
2124
2125 broadcast(
2126 Some(session.connection_id),
2127 guest_connection_ids.iter().copied(),
2128 |connection_id| {
2129 session
2130 .peer
2131 .forward_send(session.connection_id, connection_id, message.clone())
2132 },
2133 );
2134
2135 Ok(())
2136}
2137
2138/// Notify other participants that a language server has started.
2139async fn start_language_server(
2140 request: proto::StartLanguageServer,
2141 session: Session,
2142) -> Result<()> {
2143 let guest_connection_ids = session
2144 .db()
2145 .await
2146 .start_language_server(&request, session.connection_id)
2147 .await?;
2148
2149 broadcast(
2150 Some(session.connection_id),
2151 guest_connection_ids.iter().copied(),
2152 |connection_id| {
2153 session
2154 .peer
2155 .forward_send(session.connection_id, connection_id, request.clone())
2156 },
2157 );
2158 Ok(())
2159}
2160
2161/// Notify other participants that a language server has changed.
2162async fn update_language_server(
2163 request: proto::UpdateLanguageServer,
2164 session: Session,
2165) -> Result<()> {
2166 let project_id = ProjectId::from_proto(request.project_id);
2167 let project_connection_ids = session
2168 .db()
2169 .await
2170 .project_connection_ids(project_id, session.connection_id, true)
2171 .await?;
2172 broadcast(
2173 Some(session.connection_id),
2174 project_connection_ids.iter().copied(),
2175 |connection_id| {
2176 session
2177 .peer
2178 .forward_send(session.connection_id, connection_id, request.clone())
2179 },
2180 );
2181 Ok(())
2182}
2183
2184/// forward a project request to the host. These requests should be read only
2185/// as guests are allowed to send them.
2186async fn forward_read_only_project_request<T>(
2187 request: T,
2188 response: Response<T>,
2189 session: Session,
2190) -> Result<()>
2191where
2192 T: EntityMessage + RequestMessage,
2193{
2194 let project_id = ProjectId::from_proto(request.remote_entity_id());
2195 let host_connection_id = session
2196 .db()
2197 .await
2198 .host_for_read_only_project_request(project_id, session.connection_id)
2199 .await?;
2200 let payload = session
2201 .peer
2202 .forward_request(session.connection_id, host_connection_id, request)
2203 .await?;
2204 response.send(payload)?;
2205 Ok(())
2206}
2207
2208async fn forward_find_search_candidates_request(
2209 request: proto::FindSearchCandidates,
2210 response: Response<proto::FindSearchCandidates>,
2211 session: Session,
2212) -> Result<()> {
2213 let project_id = ProjectId::from_proto(request.remote_entity_id());
2214 let host_connection_id = session
2215 .db()
2216 .await
2217 .host_for_read_only_project_request(project_id, session.connection_id)
2218 .await?;
2219 let payload = session
2220 .peer
2221 .forward_request(session.connection_id, host_connection_id, request)
2222 .await?;
2223 response.send(payload)?;
2224 Ok(())
2225}
2226
2227/// forward a project request to the host. These requests are disallowed
2228/// for guests.
2229async fn forward_mutating_project_request<T>(
2230 request: T,
2231 response: Response<T>,
2232 session: Session,
2233) -> Result<()>
2234where
2235 T: EntityMessage + RequestMessage,
2236{
2237 let project_id = ProjectId::from_proto(request.remote_entity_id());
2238
2239 let host_connection_id = session
2240 .db()
2241 .await
2242 .host_for_mutating_project_request(project_id, session.connection_id)
2243 .await?;
2244 let payload = session
2245 .peer
2246 .forward_request(session.connection_id, host_connection_id, request)
2247 .await?;
2248 response.send(payload)?;
2249 Ok(())
2250}
2251
2252/// Notify other participants that a new buffer has been created
2253async fn create_buffer_for_peer(
2254 request: proto::CreateBufferForPeer,
2255 session: Session,
2256) -> Result<()> {
2257 session
2258 .db()
2259 .await
2260 .check_user_is_project_host(
2261 ProjectId::from_proto(request.project_id),
2262 session.connection_id,
2263 )
2264 .await?;
2265 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2266 session
2267 .peer
2268 .forward_send(session.connection_id, peer_id.into(), request)?;
2269 Ok(())
2270}
2271
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The required capability is derived from the operations. If every operation
/// has no variant or is an `UpdateSelections`, `ReadOnly` suffices. Any other
/// operation requires `ReadWrite`.
///
/// The update is broadcast to the guest connections returned by
/// `connections_for_buffer_update`. If the sender is not the host, the update is
/// also forwarded to the host as a request, and the host's reply is awaited
/// before the sender receives its `Ack`.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Escalate to ReadWrite if any operation does more than update selections.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // `guard` is confined to this block so it is dropped before the host
    // round-trip below is awaited.
    // NOTE(review): the name suggests it holds a database/room lock; confirm.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // When the host itself sent the update, there is no host to forward it to.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2321
2322async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2323 let project_id = ProjectId::from_proto(message.project_id);
2324
2325 let operation = message.operation.as_ref().context("invalid operation")?;
2326 let capability = match operation.variant.as_ref() {
2327 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2328 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2329 match buffer_op.variant {
2330 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2331 Capability::ReadOnly
2332 }
2333 _ => Capability::ReadWrite,
2334 }
2335 } else {
2336 Capability::ReadWrite
2337 }
2338 }
2339 Some(_) => Capability::ReadWrite,
2340 None => Capability::ReadOnly,
2341 };
2342
2343 let guard = session
2344 .db()
2345 .await
2346 .connections_for_buffer_update(project_id, session.connection_id, capability)
2347 .await?;
2348
2349 let (host, guests) = &*guard;
2350
2351 broadcast(
2352 Some(session.connection_id),
2353 guests.iter().chain([host]).copied(),
2354 |connection_id| {
2355 session
2356 .peer
2357 .forward_send(session.connection_id, connection_id, message.clone())
2358 },
2359 );
2360
2361 Ok(())
2362}
2363
2364/// Notify other participants that a project has been updated.
2365async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2366 request: T,
2367 session: Session,
2368) -> Result<()> {
2369 let project_id = ProjectId::from_proto(request.remote_entity_id());
2370 let project_connection_ids = session
2371 .db()
2372 .await
2373 .project_connection_ids(project_id, session.connection_id, false)
2374 .await?;
2375
2376 broadcast(
2377 Some(session.connection_id),
2378 project_connection_ids.iter().copied(),
2379 |connection_id| {
2380 session
2381 .peer
2382 .forward_send(session.connection_id, connection_id, request.clone())
2383 },
2384 );
2385 Ok(())
2386}
2387
2388/// Start following another user in a call.
2389async fn follow(
2390 request: proto::Follow,
2391 response: Response<proto::Follow>,
2392 session: Session,
2393) -> Result<()> {
2394 let room_id = RoomId::from_proto(request.room_id);
2395 let project_id = request.project_id.map(ProjectId::from_proto);
2396 let leader_id = request
2397 .leader_id
2398 .ok_or_else(|| anyhow!("invalid leader id"))?
2399 .into();
2400 let follower_id = session.connection_id;
2401
2402 session
2403 .db()
2404 .await
2405 .check_room_participants(room_id, leader_id, session.connection_id)
2406 .await?;
2407
2408 let response_payload = session
2409 .peer
2410 .forward_request(session.connection_id, leader_id, request)
2411 .await?;
2412 response.send(response_payload)?;
2413
2414 if let Some(project_id) = project_id {
2415 let room = session
2416 .db()
2417 .await
2418 .follow(room_id, project_id, leader_id, follower_id)
2419 .await?;
2420 room_updated(&room, &session.peer);
2421 }
2422
2423 Ok(())
2424}
2425
2426/// Stop following another user in a call.
2427async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2428 let room_id = RoomId::from_proto(request.room_id);
2429 let project_id = request.project_id.map(ProjectId::from_proto);
2430 let leader_id = request
2431 .leader_id
2432 .ok_or_else(|| anyhow!("invalid leader id"))?
2433 .into();
2434 let follower_id = session.connection_id;
2435
2436 session
2437 .db()
2438 .await
2439 .check_room_participants(room_id, leader_id, session.connection_id)
2440 .await?;
2441
2442 session
2443 .peer
2444 .forward_send(session.connection_id, leader_id, request)?;
2445
2446 if let Some(project_id) = project_id {
2447 let room = session
2448 .db()
2449 .await
2450 .unfollow(room_id, project_id, leader_id, follower_id)
2451 .await?;
2452 room_updated(&room, &session.peer);
2453 }
2454
2455 Ok(())
2456}
2457
/// Notify everyone following you of your current location.
///
/// The recipients are the project's connections when `request.project_id` is
/// set, and the room's connections otherwise. The sender is always skipped. For
/// `UpdateView` variants, the view's current leader is skipped too. A failed
/// send returns early via `?`, so any remaining recipients are not sent to.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of going through
    // `session.db()` like the other handlers. The guard also stays alive until
    // the function returns, including during the sends below. Confirm this is
    // intentional.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for connection_id in connection_ids.iter().cloned() {
        // Skip the omitted leader (compared as a peer id) and the sender itself.
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2489
2490/// Get public data about users.
2491async fn get_users(
2492 request: proto::GetUsers,
2493 response: Response<proto::GetUsers>,
2494 session: Session,
2495) -> Result<()> {
2496 let user_ids = request
2497 .user_ids
2498 .into_iter()
2499 .map(UserId::from_proto)
2500 .collect();
2501 let users = session
2502 .db()
2503 .await
2504 .get_users_by_ids(user_ids)
2505 .await?
2506 .into_iter()
2507 .map(|user| proto::User {
2508 id: user.id.to_proto(),
2509 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2510 github_login: user.github_login,
2511 email: user.email_address,
2512 name: user.name,
2513 })
2514 .collect();
2515 response.send(proto::UsersResponse { users })?;
2516 Ok(())
2517}
2518
2519/// Search for users (to invite) buy Github login
2520async fn fuzzy_search_users(
2521 request: proto::FuzzySearchUsers,
2522 response: Response<proto::FuzzySearchUsers>,
2523 session: Session,
2524) -> Result<()> {
2525 let query = request.query;
2526 let users = match query.len() {
2527 0 => vec![],
2528 1 | 2 => session
2529 .db()
2530 .await
2531 .get_user_by_github_login(&query)
2532 .await?
2533 .into_iter()
2534 .collect(),
2535 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2536 };
2537 let users = users
2538 .into_iter()
2539 .filter(|user| user.id != session.user_id())
2540 .map(|user| proto::User {
2541 id: user.id.to_proto(),
2542 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2543 github_login: user.github_login,
2544 name: user.name,
2545 email: user.email_address,
2546 })
2547 .collect();
2548 response.send(proto::UsersResponse { users })?;
2549 Ok(())
2550}
2551
2552/// Send a contact request to another user.
2553async fn request_contact(
2554 request: proto::RequestContact,
2555 response: Response<proto::RequestContact>,
2556 session: Session,
2557) -> Result<()> {
2558 let requester_id = session.user_id();
2559 let responder_id = UserId::from_proto(request.responder_id);
2560 if requester_id == responder_id {
2561 return Err(anyhow!("cannot add yourself as a contact"))?;
2562 }
2563
2564 let notifications = session
2565 .db()
2566 .await
2567 .send_contact_request(requester_id, responder_id)
2568 .await?;
2569
2570 // Update outgoing contact requests of requester
2571 let mut update = proto::UpdateContacts::default();
2572 update.outgoing_requests.push(responder_id.to_proto());
2573 for connection_id in session
2574 .connection_pool()
2575 .await
2576 .user_connection_ids(requester_id)
2577 {
2578 session.peer.send(connection_id, update.clone())?;
2579 }
2580
2581 // Update incoming contact requests of responder
2582 let mut update = proto::UpdateContacts::default();
2583 update
2584 .incoming_requests
2585 .push(proto::IncomingContactRequest {
2586 requester_id: requester_id.to_proto(),
2587 });
2588 let connection_pool = session.connection_pool().await;
2589 for connection_id in connection_pool.user_connection_ids(responder_id) {
2590 session.peer.send(connection_id, update.clone())?;
2591 }
2592
2593 send_notifications(&connection_pool, &session.peer, notifications);
2594
2595 response.send(proto::Ack {})?;
2596 Ok(())
2597}
2598
/// Accept or decline a contact request.
///
/// A `Dismiss` response only dismisses the contact notification. Any other
/// response is recorded in the database, with `accept` set when the response is
/// `Accept`. Both users' connections then receive an `UpdateContacts`:
/// - the pending request is removed on each side;
/// - when accepted, each user gains the other as a contact (built by
///   `contact_for_user`, including that user's busy status).
///
/// Finally, the notifications returned by the database are delivered.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // `db` stays in scope until the function returns, i.e. across the
    // connection-pool lock and all sends below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Dismiss` or `Accept` is recorded with `accept == false`.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status feeds the `contact_for_user` entries pushed below when accepted.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2656
2657/// Remove a contact.
2658async fn remove_contact(
2659 request: proto::RemoveContact,
2660 response: Response<proto::RemoveContact>,
2661 session: Session,
2662) -> Result<()> {
2663 let requester_id = session.user_id();
2664 let responder_id = UserId::from_proto(request.user_id);
2665 let db = session.db().await;
2666 let (contact_accepted, deleted_notification_id) =
2667 db.remove_contact(requester_id, responder_id).await?;
2668
2669 let pool = session.connection_pool().await;
2670 // Update outgoing contact requests of requester
2671 let mut update = proto::UpdateContacts::default();
2672 if contact_accepted {
2673 update.remove_contacts.push(responder_id.to_proto());
2674 } else {
2675 update
2676 .remove_outgoing_requests
2677 .push(responder_id.to_proto());
2678 }
2679 for connection_id in pool.user_connection_ids(requester_id) {
2680 session.peer.send(connection_id, update.clone())?;
2681 }
2682
2683 // Update incoming contact requests of responder
2684 let mut update = proto::UpdateContacts::default();
2685 if contact_accepted {
2686 update.remove_contacts.push(requester_id.to_proto());
2687 } else {
2688 update
2689 .remove_incoming_requests
2690 .push(requester_id.to_proto());
2691 }
2692 for connection_id in pool.user_connection_ids(responder_id) {
2693 session.peer.send(connection_id, update.clone())?;
2694 if let Some(notification_id) = deleted_notification_id {
2695 session.peer.send(
2696 connection_id,
2697 proto::DeleteNotification {
2698 notification_id: notification_id.to_proto(),
2699 },
2700 )?;
2701 }
2702 }
2703
2704 response.send(proto::Ack {})?;
2705 Ok(())
2706}
2707
2708fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2709 version.0.minor() < 139
2710}
2711
2712async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2713 let plan = session.current_plan(&session.db().await).await?;
2714
2715 session
2716 .peer
2717 .send(
2718 session.connection_id,
2719 proto::UpdateUserPlan { plan: plan.into() },
2720 )
2721 .trace_err();
2722
2723 Ok(())
2724}
2725
2726async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2727 subscribe_user_to_channels(session.user_id(), &session).await?;
2728 Ok(())
2729}
2730
2731async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2732 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2733 let mut pool = session.connection_pool().await;
2734 for membership in &channels_for_user.channel_memberships {
2735 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2736 }
2737 session.peer.send(
2738 session.connection_id,
2739 build_update_user_channels(&channels_for_user),
2740 )?;
2741 session.peer.send(
2742 session.connection_id,
2743 build_channels_update(channels_for_user),
2744 )?;
2745 Ok(())
2746}
2747
2748/// Creates a new channel.
2749async fn create_channel(
2750 request: proto::CreateChannel,
2751 response: Response<proto::CreateChannel>,
2752 session: Session,
2753) -> Result<()> {
2754 let db = session.db().await;
2755
2756 let parent_id = request.parent_id.map(ChannelId::from_proto);
2757 let (channel, membership) = db
2758 .create_channel(&request.name, parent_id, session.user_id())
2759 .await?;
2760
2761 let root_id = channel.root_id();
2762 let channel = Channel::from_model(channel);
2763
2764 response.send(proto::CreateChannelResponse {
2765 channel: Some(channel.to_proto()),
2766 parent_id: request.parent_id,
2767 })?;
2768
2769 let mut connection_pool = session.connection_pool().await;
2770 if let Some(membership) = membership {
2771 connection_pool.subscribe_to_channel(
2772 membership.user_id,
2773 membership.channel_id,
2774 membership.role,
2775 );
2776 let update = proto::UpdateUserChannels {
2777 channel_memberships: vec![proto::ChannelMembership {
2778 channel_id: membership.channel_id.to_proto(),
2779 role: membership.role.into(),
2780 }],
2781 ..Default::default()
2782 };
2783 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2784 session.peer.send(connection_id, update.clone())?;
2785 }
2786 }
2787
2788 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2789 if !role.can_see_channel(channel.visibility) {
2790 continue;
2791 }
2792
2793 let update = proto::UpdateChannels {
2794 channels: vec![channel.to_proto()],
2795 ..Default::default()
2796 };
2797 session.peer.send(connection_id, update.clone())?;
2798 }
2799
2800 Ok(())
2801}
2802
2803/// Delete a channel
2804async fn delete_channel(
2805 request: proto::DeleteChannel,
2806 response: Response<proto::DeleteChannel>,
2807 session: Session,
2808) -> Result<()> {
2809 let db = session.db().await;
2810
2811 let channel_id = request.channel_id;
2812 let (root_channel, removed_channels) = db
2813 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2814 .await?;
2815 response.send(proto::Ack {})?;
2816
2817 // Notify members of removed channels
2818 let mut update = proto::UpdateChannels::default();
2819 update
2820 .delete_channels
2821 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2822
2823 let connection_pool = session.connection_pool().await;
2824 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2825 session.peer.send(connection_id, update.clone())?;
2826 }
2827
2828 Ok(())
2829}
2830
2831/// Invite someone to join a channel.
2832async fn invite_channel_member(
2833 request: proto::InviteChannelMember,
2834 response: Response<proto::InviteChannelMember>,
2835 session: Session,
2836) -> Result<()> {
2837 let db = session.db().await;
2838 let channel_id = ChannelId::from_proto(request.channel_id);
2839 let invitee_id = UserId::from_proto(request.user_id);
2840 let InviteMemberResult {
2841 channel,
2842 notifications,
2843 } = db
2844 .invite_channel_member(
2845 channel_id,
2846 invitee_id,
2847 session.user_id(),
2848 request.role().into(),
2849 )
2850 .await?;
2851
2852 let update = proto::UpdateChannels {
2853 channel_invitations: vec![channel.to_proto()],
2854 ..Default::default()
2855 };
2856
2857 let connection_pool = session.connection_pool().await;
2858 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2859 session.peer.send(connection_id, update.clone())?;
2860 }
2861
2862 send_notifications(&connection_pool, &session.peer, notifications);
2863
2864 response.send(proto::Ack {})?;
2865 Ok(())
2866}
2867
2868/// remove someone from a channel
2869async fn remove_channel_member(
2870 request: proto::RemoveChannelMember,
2871 response: Response<proto::RemoveChannelMember>,
2872 session: Session,
2873) -> Result<()> {
2874 let db = session.db().await;
2875 let channel_id = ChannelId::from_proto(request.channel_id);
2876 let member_id = UserId::from_proto(request.user_id);
2877
2878 let RemoveChannelMemberResult {
2879 membership_update,
2880 notification_id,
2881 } = db
2882 .remove_channel_member(channel_id, member_id, session.user_id())
2883 .await?;
2884
2885 let mut connection_pool = session.connection_pool().await;
2886 notify_membership_updated(
2887 &mut connection_pool,
2888 membership_update,
2889 member_id,
2890 &session.peer,
2891 );
2892 for connection_id in connection_pool.user_connection_ids(member_id) {
2893 if let Some(notification_id) = notification_id {
2894 session
2895 .peer
2896 .send(
2897 connection_id,
2898 proto::DeleteNotification {
2899 notification_id: notification_id.to_proto(),
2900 },
2901 )
2902 .trace_err();
2903 }
2904 }
2905
2906 response.send(proto::Ack {})?;
2907 Ok(())
2908}
2909
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
/// NOTE(review): that invariant is presumably enforced inside
/// `db.set_channel_visibility`; it is not checked in this handler.
///
/// After the database update, every user subscribed to the channel's root is
/// re-evaluated:
/// - users whose role can see the new visibility are subscribed to the channel
///   and sent it;
/// - all other users are unsubscribed and told the channel was deleted.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the pool is mutated (subscribe/unsubscribe) inside the loop,
    // so it cannot stay borrowed by the `channel_user_ids` iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2956
2957/// Alter the role for a user in the channel.
2958async fn set_channel_member_role(
2959 request: proto::SetChannelMemberRole,
2960 response: Response<proto::SetChannelMemberRole>,
2961 session: Session,
2962) -> Result<()> {
2963 let db = session.db().await;
2964 let channel_id = ChannelId::from_proto(request.channel_id);
2965 let member_id = UserId::from_proto(request.user_id);
2966 let result = db
2967 .set_channel_member_role(
2968 channel_id,
2969 session.user_id(),
2970 member_id,
2971 request.role().into(),
2972 )
2973 .await?;
2974
2975 match result {
2976 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2977 let mut connection_pool = session.connection_pool().await;
2978 notify_membership_updated(
2979 &mut connection_pool,
2980 membership_update,
2981 member_id,
2982 &session.peer,
2983 )
2984 }
2985 db::SetMemberRoleResult::InviteUpdated(channel) => {
2986 let update = proto::UpdateChannels {
2987 channel_invitations: vec![channel.to_proto()],
2988 ..Default::default()
2989 };
2990
2991 for connection_id in session
2992 .connection_pool()
2993 .await
2994 .user_connection_ids(member_id)
2995 {
2996 session.peer.send(connection_id, update.clone())?;
2997 }
2998 }
2999 }
3000
3001 response.send(proto::Ack {})?;
3002 Ok(())
3003}
3004
3005/// Change the name of a channel
3006async fn rename_channel(
3007 request: proto::RenameChannel,
3008 response: Response<proto::RenameChannel>,
3009 session: Session,
3010) -> Result<()> {
3011 let db = session.db().await;
3012 let channel_id = ChannelId::from_proto(request.channel_id);
3013 let channel_model = db
3014 .rename_channel(channel_id, session.user_id(), &request.name)
3015 .await?;
3016 let root_id = channel_model.root_id();
3017 let channel = Channel::from_model(channel_model);
3018
3019 response.send(proto::RenameChannelResponse {
3020 channel: Some(channel.to_proto()),
3021 })?;
3022
3023 let connection_pool = session.connection_pool().await;
3024 let update = proto::UpdateChannels {
3025 channels: vec![channel.to_proto()],
3026 ..Default::default()
3027 };
3028 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3029 if role.can_see_channel(channel.visibility) {
3030 session.peer.send(connection_id, update.clone())?;
3031 }
3032 }
3033
3034 Ok(())
3035}
3036
3037/// Move a channel to a new parent.
3038async fn move_channel(
3039 request: proto::MoveChannel,
3040 response: Response<proto::MoveChannel>,
3041 session: Session,
3042) -> Result<()> {
3043 let channel_id = ChannelId::from_proto(request.channel_id);
3044 let to = ChannelId::from_proto(request.to);
3045
3046 let (root_id, channels) = session
3047 .db()
3048 .await
3049 .move_channel(channel_id, to, session.user_id())
3050 .await?;
3051
3052 let connection_pool = session.connection_pool().await;
3053 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3054 let channels = channels
3055 .iter()
3056 .filter_map(|channel| {
3057 if role.can_see_channel(channel.visibility) {
3058 Some(channel.to_proto())
3059 } else {
3060 None
3061 }
3062 })
3063 .collect::<Vec<_>>();
3064 if channels.is_empty() {
3065 continue;
3066 }
3067
3068 let update = proto::UpdateChannels {
3069 channels,
3070 ..Default::default()
3071 };
3072
3073 session.peer.send(connection_id, update.clone())?;
3074 }
3075
3076 response.send(Ack {})?;
3077 Ok(())
3078}
3079
3080/// Get the list of channel members
3081async fn get_channel_members(
3082 request: proto::GetChannelMembers,
3083 response: Response<proto::GetChannelMembers>,
3084 session: Session,
3085) -> Result<()> {
3086 let db = session.db().await;
3087 let channel_id = ChannelId::from_proto(request.channel_id);
3088 let limit = if request.limit == 0 {
3089 u16::MAX as u64
3090 } else {
3091 request.limit
3092 };
3093 let (members, users) = db
3094 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3095 .await?;
3096 response.send(proto::GetChannelMembersResponse { members, users })?;
3097 Ok(())
3098}
3099
3100/// Accept or decline a channel invitation.
3101async fn respond_to_channel_invite(
3102 request: proto::RespondToChannelInvite,
3103 response: Response<proto::RespondToChannelInvite>,
3104 session: Session,
3105) -> Result<()> {
3106 let db = session.db().await;
3107 let channel_id = ChannelId::from_proto(request.channel_id);
3108 let RespondToChannelInvite {
3109 membership_update,
3110 notifications,
3111 } = db
3112 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3113 .await?;
3114
3115 let mut connection_pool = session.connection_pool().await;
3116 if let Some(membership_update) = membership_update {
3117 notify_membership_updated(
3118 &mut connection_pool,
3119 membership_update,
3120 session.user_id(),
3121 &session.peer,
3122 );
3123 } else {
3124 let update = proto::UpdateChannels {
3125 remove_channel_invitations: vec![channel_id.to_proto()],
3126 ..Default::default()
3127 };
3128
3129 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3130 session.peer.send(connection_id, update.clone())?;
3131 }
3132 };
3133
3134 send_notifications(&connection_pool, &session.peer, notifications);
3135
3136 response.send(proto::Ack {})?;
3137
3138 Ok(())
3139}
3140
3141/// Join the channels' room
3142async fn join_channel(
3143 request: proto::JoinChannel,
3144 response: Response<proto::JoinChannel>,
3145 session: Session,
3146) -> Result<()> {
3147 let channel_id = ChannelId::from_proto(request.channel_id);
3148 join_channel_internal(channel_id, Box::new(response), session).await
3149}
3150
/// Lets `join_channel_internal` reply with a `JoinRoomResponse` without knowing
/// which request it is serving. It is implemented for the responses of both
/// `JoinChannel` and `JoinRoom`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for a `JoinChannel` request.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for a `JoinRoom` request.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3164
/// Join the room that belongs to `channel_id` on behalf of `session`.
///
/// Generic over any response handle implementing `JoinChannelInternalResponse`.
/// Steps:
/// 1. Evict a stale room connection left behind by a previous session, if any.
/// 2. Join the channel's room in the database.
/// 3. Mint LiveKit credentials when a LiveKit client is configured (guests get a
///    guest token and `can_publish: false`).
/// 4. Reply to the client, then push membership and room updates.
/// 5. Broadcast the channel's participant list and refresh the user's contacts.
///
/// # Errors
/// Returns an error if a database call or the reply fails, or if the database
/// did not return the joined channel. That last check runs after the response
/// has already been sent and the room update broadcast.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard and the connection-pool guard taken inside this block
    // are both dropped at its end, before the channel-wide broadcast below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room (presumably
            // `leave_room_for_session` acquires the database itself), then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the
        // token fails (the failure is logged by `trace_err`). Guests receive a
        // guest token and are not allowed to publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply first; the notifications below go out afterwards.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell everyone who can see the channel about its new participant list.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3260
3261/// Start editing the channel notes
3262async fn join_channel_buffer(
3263 request: proto::JoinChannelBuffer,
3264 response: Response<proto::JoinChannelBuffer>,
3265 session: Session,
3266) -> Result<()> {
3267 let db = session.db().await;
3268 let channel_id = ChannelId::from_proto(request.channel_id);
3269
3270 let open_response = db
3271 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3272 .await?;
3273
3274 let collaborators = open_response.collaborators.clone();
3275 response.send(open_response)?;
3276
3277 let update = UpdateChannelBufferCollaborators {
3278 channel_id: channel_id.to_proto(),
3279 collaborators: collaborators.clone(),
3280 };
3281 channel_buffer_updated(
3282 session.connection_id,
3283 collaborators
3284 .iter()
3285 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3286 &update,
3287 &session.peer,
3288 );
3289
3290 Ok(())
3291}
3292
3293/// Edit the channel notes
3294async fn update_channel_buffer(
3295 request: proto::UpdateChannelBuffer,
3296 session: Session,
3297) -> Result<()> {
3298 let db = session.db().await;
3299 let channel_id = ChannelId::from_proto(request.channel_id);
3300
3301 let (collaborators, epoch, version) = db
3302 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3303 .await?;
3304
3305 channel_buffer_updated(
3306 session.connection_id,
3307 collaborators.clone(),
3308 &proto::UpdateChannelBuffer {
3309 channel_id: channel_id.to_proto(),
3310 operations: request.operations,
3311 },
3312 &session.peer,
3313 );
3314
3315 let pool = &*session.connection_pool().await;
3316
3317 let non_collaborators =
3318 pool.channel_connection_ids(channel_id)
3319 .filter_map(|(connection_id, _)| {
3320 if collaborators.contains(&connection_id) {
3321 None
3322 } else {
3323 Some(connection_id)
3324 }
3325 });
3326
3327 broadcast(None, non_collaborators, |peer_id| {
3328 session.peer.send(
3329 peer_id,
3330 proto::UpdateChannels {
3331 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3332 channel_id: channel_id.to_proto(),
3333 epoch: epoch as u64,
3334 version: version.clone(),
3335 }],
3336 ..Default::default()
3337 },
3338 )
3339 });
3340
3341 Ok(())
3342}
3343
3344/// Rejoin the channel notes after a connection blip
3345async fn rejoin_channel_buffers(
3346 request: proto::RejoinChannelBuffers,
3347 response: Response<proto::RejoinChannelBuffers>,
3348 session: Session,
3349) -> Result<()> {
3350 let db = session.db().await;
3351 let buffers = db
3352 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3353 .await?;
3354
3355 for rejoined_buffer in &buffers {
3356 let collaborators_to_notify = rejoined_buffer
3357 .buffer
3358 .collaborators
3359 .iter()
3360 .filter_map(|c| Some(c.peer_id?.into()));
3361 channel_buffer_updated(
3362 session.connection_id,
3363 collaborators_to_notify,
3364 &proto::UpdateChannelBufferCollaborators {
3365 channel_id: rejoined_buffer.buffer.channel_id,
3366 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3367 },
3368 &session.peer,
3369 );
3370 }
3371
3372 response.send(proto::RejoinChannelBuffersResponse {
3373 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3374 })?;
3375
3376 Ok(())
3377}
3378
3379/// Stop editing the channel notes
3380async fn leave_channel_buffer(
3381 request: proto::LeaveChannelBuffer,
3382 response: Response<proto::LeaveChannelBuffer>,
3383 session: Session,
3384) -> Result<()> {
3385 let db = session.db().await;
3386 let channel_id = ChannelId::from_proto(request.channel_id);
3387
3388 let left_buffer = db
3389 .leave_channel_buffer(channel_id, session.connection_id)
3390 .await?;
3391
3392 response.send(Ack {})?;
3393
3394 channel_buffer_updated(
3395 session.connection_id,
3396 left_buffer.connections,
3397 &proto::UpdateChannelBufferCollaborators {
3398 channel_id: channel_id.to_proto(),
3399 collaborators: left_buffer.collaborators,
3400 },
3401 &session.peer,
3402 );
3403
3404 Ok(())
3405}
3406
3407fn channel_buffer_updated<T: EnvelopedMessage>(
3408 sender_id: ConnectionId,
3409 collaborators: impl IntoIterator<Item = ConnectionId>,
3410 message: &T,
3411 peer: &Peer,
3412) {
3413 broadcast(Some(sender_id), collaborators, |peer_id| {
3414 peer.send(peer_id, message.clone())
3415 });
3416}
3417
3418fn send_notifications(
3419 connection_pool: &ConnectionPool,
3420 peer: &Peer,
3421 notifications: db::NotificationBatch,
3422) {
3423 for (user_id, notification) in notifications {
3424 for connection_id in connection_pool.user_connection_ids(user_id) {
3425 if let Err(error) = peer.send(
3426 connection_id,
3427 proto::AddNotification {
3428 notification: Some(notification.clone()),
3429 },
3430 ) {
3431 tracing::error!(
3432 "failed to send notification to {:?} {}",
3433 connection_id,
3434 error
3435 );
3436 }
3437 }
3438 }
3439}
3440
/// Post a new message to a channel's chat.
///
/// Validates the trimmed body and the client nonce, then stores the message.
/// After that it:
/// - broadcasts the message to the participant connections returned by the
///   database, excluding the sender;
/// - replies to the sender with the stored message;
/// - sends the channel's other connections (non-participants) the latest
///   message id, presumably so clients can track unread state;
/// - delivers the notifications returned by the database.
///
/// # Errors
/// Fails if the trimmed body is empty or longer than `MAX_MESSAGE_LEN`, if the
/// nonce is missing, or if the database write or the reply fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged; if its
    // ranges index into the untrimmed body they may be offset after trimming.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // The sender's own connection is excluded here; it gets the message via
    // the response below.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel that are not chat participants
    // only receive the latest message id, not the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3535
3536/// Delete a channel message
3537async fn remove_channel_message(
3538 request: proto::RemoveChannelMessage,
3539 response: Response<proto::RemoveChannelMessage>,
3540 session: Session,
3541) -> Result<()> {
3542 let channel_id = ChannelId::from_proto(request.channel_id);
3543 let message_id = MessageId::from_proto(request.message_id);
3544 let (connection_ids, existing_notification_ids) = session
3545 .db()
3546 .await
3547 .remove_channel_message(channel_id, message_id, session.user_id())
3548 .await?;
3549
3550 broadcast(
3551 Some(session.connection_id),
3552 connection_ids,
3553 move |connection| {
3554 session.peer.send(connection, request.clone())?;
3555
3556 for notification_id in &existing_notification_ids {
3557 session.peer.send(
3558 connection,
3559 proto::DeleteNotification {
3560 notification_id: (*notification_id).to_proto(),
3561 },
3562 )?;
3563 }
3564
3565 Ok(())
3566 },
3567 );
3568 response.send(proto::Ack {})?;
3569 Ok(())
3570}
3571
3572async fn update_channel_message(
3573 request: proto::UpdateChannelMessage,
3574 response: Response<proto::UpdateChannelMessage>,
3575 session: Session,
3576) -> Result<()> {
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578 let message_id = MessageId::from_proto(request.message_id);
3579 let updated_at = OffsetDateTime::now_utc();
3580 let UpdatedChannelMessage {
3581 message_id,
3582 participant_connection_ids,
3583 notifications,
3584 reply_to_message_id,
3585 timestamp,
3586 deleted_mention_notification_ids,
3587 updated_mention_notifications,
3588 } = session
3589 .db()
3590 .await
3591 .update_channel_message(
3592 channel_id,
3593 message_id,
3594 session.user_id(),
3595 request.body.as_str(),
3596 &request.mentions,
3597 updated_at,
3598 )
3599 .await?;
3600
3601 let nonce = request
3602 .nonce
3603 .clone()
3604 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3605
3606 let message = proto::ChannelMessage {
3607 sender_id: session.user_id().to_proto(),
3608 id: message_id.to_proto(),
3609 body: request.body.clone(),
3610 mentions: request.mentions.clone(),
3611 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3612 nonce: Some(nonce),
3613 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3614 edited_at: Some(updated_at.unix_timestamp() as u64),
3615 };
3616
3617 response.send(proto::Ack {})?;
3618
3619 let pool = &*session.connection_pool().await;
3620 broadcast(
3621 Some(session.connection_id),
3622 participant_connection_ids,
3623 |connection| {
3624 session.peer.send(
3625 connection,
3626 proto::ChannelMessageUpdate {
3627 channel_id: channel_id.to_proto(),
3628 message: Some(message.clone()),
3629 },
3630 )?;
3631
3632 for notification_id in &deleted_mention_notification_ids {
3633 session.peer.send(
3634 connection,
3635 proto::DeleteNotification {
3636 notification_id: (*notification_id).to_proto(),
3637 },
3638 )?;
3639 }
3640
3641 for notification in &updated_mention_notifications {
3642 session.peer.send(
3643 connection,
3644 proto::UpdateNotification {
3645 notification: Some(notification.clone()),
3646 },
3647 )?;
3648 }
3649
3650 Ok(())
3651 },
3652 );
3653
3654 send_notifications(pool, &session.peer, notifications);
3655
3656 Ok(())
3657}
3658
3659/// Mark a channel message as read
3660async fn acknowledge_channel_message(
3661 request: proto::AckChannelMessage,
3662 session: Session,
3663) -> Result<()> {
3664 let channel_id = ChannelId::from_proto(request.channel_id);
3665 let message_id = MessageId::from_proto(request.message_id);
3666 let notifications = session
3667 .db()
3668 .await
3669 .observe_channel_message(channel_id, session.user_id(), message_id)
3670 .await?;
3671 send_notifications(
3672 &*session.connection_pool().await,
3673 &session.peer,
3674 notifications,
3675 );
3676 Ok(())
3677}
3678
3679/// Mark a buffer version as synced
3680async fn acknowledge_buffer_version(
3681 request: proto::AckBufferOperation,
3682 session: Session,
3683) -> Result<()> {
3684 let buffer_id = BufferId::from_proto(request.buffer_id);
3685 session
3686 .db()
3687 .await
3688 .observe_buffer_version(
3689 buffer_id,
3690 session.user_id(),
3691 request.epoch as i32,
3692 &request.version,
3693 )
3694 .await?;
3695 Ok(())
3696}
3697
3698async fn count_language_model_tokens(
3699 request: proto::CountLanguageModelTokens,
3700 response: Response<proto::CountLanguageModelTokens>,
3701 session: Session,
3702 config: &Config,
3703) -> Result<()> {
3704 authorize_access_to_legacy_llm_endpoints(&session).await?;
3705
3706 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3707 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3708 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3709 };
3710
3711 session
3712 .app_state
3713 .rate_limiter
3714 .check(&*rate_limit, session.user_id())
3715 .await?;
3716
3717 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3718 Some(proto::LanguageModelProvider::Google) => {
3719 let api_key = config
3720 .google_ai_api_key
3721 .as_ref()
3722 .context("no Google AI API key configured on the server")?;
3723 google_ai::count_tokens(
3724 session.http_client.as_ref(),
3725 google_ai::API_URL,
3726 api_key,
3727 serde_json::from_str(&request.request)?,
3728 )
3729 .await?
3730 }
3731 _ => return Err(anyhow!("unsupported provider"))?,
3732 };
3733
3734 response.send(proto::CountLanguageModelTokensResponse {
3735 token_count: result.total_tokens as u32,
3736 })?;
3737
3738 Ok(())
3739}
3740
3741struct ZedProCountLanguageModelTokensRateLimit;
3742
3743impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3744 fn capacity(&self) -> usize {
3745 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3746 .ok()
3747 .and_then(|v| v.parse().ok())
3748 .unwrap_or(600) // Picked arbitrarily
3749 }
3750
3751 fn refill_duration(&self) -> chrono::Duration {
3752 chrono::Duration::hours(1)
3753 }
3754
3755 fn db_name(&self) -> &'static str {
3756 "zed-pro:count-language-model-tokens"
3757 }
3758}
3759
3760struct FreeCountLanguageModelTokensRateLimit;
3761
3762impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3763 fn capacity(&self) -> usize {
3764 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3765 .ok()
3766 .and_then(|v| v.parse().ok())
3767 .unwrap_or(600 / 10) // Picked arbitrarily
3768 }
3769
3770 fn refill_duration(&self) -> chrono::Duration {
3771 chrono::Duration::hours(1)
3772 }
3773
3774 fn db_name(&self) -> &'static str {
3775 "free:count-language-model-tokens"
3776 }
3777}
3778
3779struct ZedProComputeEmbeddingsRateLimit;
3780
3781impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3782 fn capacity(&self) -> usize {
3783 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3784 .ok()
3785 .and_then(|v| v.parse().ok())
3786 .unwrap_or(5000) // Picked arbitrarily
3787 }
3788
3789 fn refill_duration(&self) -> chrono::Duration {
3790 chrono::Duration::hours(1)
3791 }
3792
3793 fn db_name(&self) -> &'static str {
3794 "zed-pro:compute-embeddings"
3795 }
3796}
3797
3798struct FreeComputeEmbeddingsRateLimit;
3799
3800impl RateLimit for FreeComputeEmbeddingsRateLimit {
3801 fn capacity(&self) -> usize {
3802 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3803 .ok()
3804 .and_then(|v| v.parse().ok())
3805 .unwrap_or(5000 / 10) // Picked arbitrarily
3806 }
3807
3808 fn refill_duration(&self) -> chrono::Duration {
3809 chrono::Duration::hours(1)
3810 }
3811
3812 fn db_name(&self) -> &'static str {
3813 "free:compute-embeddings"
3814 }
3815}
3816
3817async fn compute_embeddings(
3818 request: proto::ComputeEmbeddings,
3819 response: Response<proto::ComputeEmbeddings>,
3820 session: Session,
3821 api_key: Option<Arc<str>>,
3822) -> Result<()> {
3823 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3824 authorize_access_to_legacy_llm_endpoints(&session).await?;
3825
3826 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3827 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3828 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3829 };
3830
3831 session
3832 .app_state
3833 .rate_limiter
3834 .check(&*rate_limit, session.user_id())
3835 .await?;
3836
3837 let embeddings = match request.model.as_str() {
3838 "openai/text-embedding-3-small" => {
3839 open_ai::embed(
3840 session.http_client.as_ref(),
3841 OPEN_AI_API_URL,
3842 &api_key,
3843 OpenAiEmbeddingModel::TextEmbedding3Small,
3844 request.texts.iter().map(|text| text.as_str()),
3845 )
3846 .await?
3847 }
3848 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3849 };
3850
3851 let embeddings = request
3852 .texts
3853 .iter()
3854 .map(|text| {
3855 let mut hasher = sha2::Sha256::new();
3856 hasher.update(text.as_bytes());
3857 let result = hasher.finalize();
3858 result.to_vec()
3859 })
3860 .zip(
3861 embeddings
3862 .data
3863 .into_iter()
3864 .map(|embedding| embedding.embedding),
3865 )
3866 .collect::<HashMap<_, _>>();
3867
3868 let db = session.db().await;
3869 db.save_embeddings(&request.model, &embeddings)
3870 .await
3871 .context("failed to save embeddings")
3872 .trace_err();
3873
3874 response.send(proto::ComputeEmbeddingsResponse {
3875 embeddings: embeddings
3876 .into_iter()
3877 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3878 .collect(),
3879 })?;
3880 Ok(())
3881}
3882
3883async fn get_cached_embeddings(
3884 request: proto::GetCachedEmbeddings,
3885 response: Response<proto::GetCachedEmbeddings>,
3886 session: Session,
3887) -> Result<()> {
3888 authorize_access_to_legacy_llm_endpoints(&session).await?;
3889
3890 let db = session.db().await;
3891 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3892
3893 response.send(proto::GetCachedEmbeddingsResponse {
3894 embeddings: embeddings
3895 .into_iter()
3896 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3897 .collect(),
3898 })?;
3899 Ok(())
3900}
3901
3902/// This is leftover from before the LLM service.
3903///
3904/// The endpoints protected by this check will be moved there eventually.
3905async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3906 if session.is_staff() {
3907 Ok(())
3908 } else {
3909 Err(anyhow!("permission denied"))?
3910 }
3911}
3912
3913/// Get a Supermaven API key for the user
3914async fn get_supermaven_api_key(
3915 _request: proto::GetSupermavenApiKey,
3916 response: Response<proto::GetSupermavenApiKey>,
3917 session: Session,
3918) -> Result<()> {
3919 let user_id: String = session.user_id().to_string();
3920 if !session.is_staff() {
3921 return Err(anyhow!("supermaven not enabled for this account"))?;
3922 }
3923
3924 let email = session
3925 .email()
3926 .ok_or_else(|| anyhow!("user must have an email"))?;
3927
3928 let supermaven_admin_api = session
3929 .supermaven_client
3930 .as_ref()
3931 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3932
3933 let result = supermaven_admin_api
3934 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3935 .await?;
3936
3937 response.send(proto::GetSupermavenApiKeyResponse {
3938 api_key: result.api_key,
3939 })?;
3940
3941 Ok(())
3942}
3943
3944/// Start receiving chat updates for a channel
3945async fn join_channel_chat(
3946 request: proto::JoinChannelChat,
3947 response: Response<proto::JoinChannelChat>,
3948 session: Session,
3949) -> Result<()> {
3950 let channel_id = ChannelId::from_proto(request.channel_id);
3951
3952 let db = session.db().await;
3953 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3954 .await?;
3955 let messages = db
3956 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3957 .await?;
3958 response.send(proto::JoinChannelChatResponse {
3959 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3960 messages,
3961 })?;
3962 Ok(())
3963}
3964
3965/// Stop receiving chat updates for a channel
3966async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3967 let channel_id = ChannelId::from_proto(request.channel_id);
3968 session
3969 .db()
3970 .await
3971 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3972 .await?;
3973 Ok(())
3974}
3975
3976/// Retrieve the chat history for a channel
3977async fn get_channel_messages(
3978 request: proto::GetChannelMessages,
3979 response: Response<proto::GetChannelMessages>,
3980 session: Session,
3981) -> Result<()> {
3982 let channel_id = ChannelId::from_proto(request.channel_id);
3983 let messages = session
3984 .db()
3985 .await
3986 .get_channel_messages(
3987 channel_id,
3988 session.user_id(),
3989 MESSAGE_COUNT_PER_PAGE,
3990 Some(MessageId::from_proto(request.before_message_id)),
3991 )
3992 .await?;
3993 response.send(proto::GetChannelMessagesResponse {
3994 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3995 messages,
3996 })?;
3997 Ok(())
3998}
3999
4000/// Retrieve specific chat messages
4001async fn get_channel_messages_by_id(
4002 request: proto::GetChannelMessagesById,
4003 response: Response<proto::GetChannelMessagesById>,
4004 session: Session,
4005) -> Result<()> {
4006 let message_ids = request
4007 .message_ids
4008 .iter()
4009 .map(|id| MessageId::from_proto(*id))
4010 .collect::<Vec<_>>();
4011 let messages = session
4012 .db()
4013 .await
4014 .get_channel_messages_by_id(session.user_id(), &message_ids)
4015 .await?;
4016 response.send(proto::GetChannelMessagesResponse {
4017 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4018 messages,
4019 })?;
4020 Ok(())
4021}
4022
4023/// Retrieve the current users notifications
4024async fn get_notifications(
4025 request: proto::GetNotifications,
4026 response: Response<proto::GetNotifications>,
4027 session: Session,
4028) -> Result<()> {
4029 let notifications = session
4030 .db()
4031 .await
4032 .get_notifications(
4033 session.user_id(),
4034 NOTIFICATION_COUNT_PER_PAGE,
4035 request.before_id.map(db::NotificationId::from_proto),
4036 )
4037 .await?;
4038 response.send(proto::GetNotificationsResponse {
4039 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4040 notifications,
4041 })?;
4042 Ok(())
4043}
4044
4045/// Mark notifications as read
4046async fn mark_notification_as_read(
4047 request: proto::MarkNotificationRead,
4048 response: Response<proto::MarkNotificationRead>,
4049 session: Session,
4050) -> Result<()> {
4051 let database = &session.db().await;
4052 let notifications = database
4053 .mark_notification_as_read_by_id(
4054 session.user_id(),
4055 NotificationId::from_proto(request.notification_id),
4056 )
4057 .await?;
4058 send_notifications(
4059 &*session.connection_pool().await,
4060 &session.peer,
4061 notifications,
4062 );
4063 response.send(proto::Ack {})?;
4064 Ok(())
4065}
4066
4067/// Get the current users information
4068async fn get_private_user_info(
4069 _request: proto::GetPrivateUserInfo,
4070 response: Response<proto::GetPrivateUserInfo>,
4071 session: Session,
4072) -> Result<()> {
4073 let db = session.db().await;
4074
4075 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4076 let user = db
4077 .get_user_by_id(session.user_id())
4078 .await?
4079 .ok_or_else(|| anyhow!("user not found"))?;
4080 let flags = db.get_user_flags(session.user_id()).await?;
4081
4082 response.send(proto::GetPrivateUserInfoResponse {
4083 metrics_id,
4084 staff: user.admin,
4085 flags,
4086 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4087 })?;
4088 Ok(())
4089}
4090
4091/// Accept the terms of service (tos) on behalf of the current user
4092async fn accept_terms_of_service(
4093 _request: proto::AcceptTermsOfService,
4094 response: Response<proto::AcceptTermsOfService>,
4095 session: Session,
4096) -> Result<()> {
4097 let db = session.db().await;
4098
4099 let accepted_tos_at = Utc::now();
4100 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4101 .await?;
4102
4103 response.send(proto::AcceptTermsOfServiceResponse {
4104 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4105 })?;
4106 Ok(())
4107}
4108
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days. NOTE(review): not referenced by `get_llm_api_token` in this
/// file; confirm where it is enforced before relying on it.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4111
4112async fn get_llm_api_token(
4113 _request: proto::GetLlmToken,
4114 response: Response<proto::GetLlmToken>,
4115 session: Session,
4116) -> Result<()> {
4117 let db = session.db().await;
4118
4119 let flags = db.get_user_flags(session.user_id()).await?;
4120 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4121
4122 if !session.is_staff() && !has_language_models_feature_flag {
4123 Err(anyhow!("permission denied"))?
4124 }
4125
4126 let user_id = session.user_id();
4127 let user = db
4128 .get_user_by_id(user_id)
4129 .await?
4130 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4131
4132 if user.accepted_tos_at.is_none() {
4133 Err(anyhow!("terms of service not accepted"))?
4134 }
4135
4136 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4137 let billing_preferences = db.get_billing_preferences(user.id).await?;
4138
4139 let token = LlmTokenClaims::create(
4140 &user,
4141 session.is_staff(),
4142 billing_preferences,
4143 &flags,
4144 has_llm_subscription,
4145 session.current_plan(&db).await?,
4146 session.system_id.clone(),
4147 &session.app_state.config,
4148 )?;
4149 response.send(proto::GetLlmTokenResponse { token })?;
4150 Ok(())
4151}
4152
4153fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4154 let message = match message {
4155 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4156 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4157 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4158 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4159 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4160 code: frame.code.into(),
4161 reason: frame.reason,
4162 })),
4163 // We should never receive a frame while reading the message, according
4164 // to the `tungstenite` maintainers:
4165 //
4166 // > It cannot occur when you read messages from the WebSocket, but it
4167 // > can be used when you want to send the raw frames (e.g. you want to
4168 // > send the frames to the WebSocket without composing the full message first).
4169 // >
4170 // > — https://github.com/snapview/tungstenite-rs/issues/268
4171 TungsteniteMessage::Frame(_) => {
4172 bail!("received an unexpected frame while reading the message")
4173 }
4174 };
4175
4176 Ok(message)
4177}
4178
4179fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4180 match message {
4181 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4182 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4183 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4184 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4185 AxumMessage::Close(frame) => {
4186 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4187 code: frame.code.into(),
4188 reason: frame.reason,
4189 }))
4190 }
4191 }
4192}
4193
4194fn notify_membership_updated(
4195 connection_pool: &mut ConnectionPool,
4196 result: MembershipUpdated,
4197 user_id: UserId,
4198 peer: &Peer,
4199) {
4200 for membership in &result.new_channels.channel_memberships {
4201 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4202 }
4203 for channel_id in &result.removed_channels {
4204 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4205 }
4206
4207 let user_channels_update = proto::UpdateUserChannels {
4208 channel_memberships: result
4209 .new_channels
4210 .channel_memberships
4211 .iter()
4212 .map(|cm| proto::ChannelMembership {
4213 channel_id: cm.channel_id.to_proto(),
4214 role: cm.role.into(),
4215 })
4216 .collect(),
4217 ..Default::default()
4218 };
4219
4220 let mut update = build_channels_update(result.new_channels);
4221 update.delete_channels = result
4222 .removed_channels
4223 .into_iter()
4224 .map(|id| id.to_proto())
4225 .collect();
4226 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4227
4228 for connection_id in connection_pool.user_connection_ids(user_id) {
4229 peer.send(connection_id, user_channels_update.clone())
4230 .trace_err();
4231 peer.send(connection_id, update.clone()).trace_err();
4232 }
4233}
4234
4235fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4236 proto::UpdateUserChannels {
4237 channel_memberships: channels
4238 .channel_memberships
4239 .iter()
4240 .map(|m| proto::ChannelMembership {
4241 channel_id: m.channel_id.to_proto(),
4242 role: m.role.into(),
4243 })
4244 .collect(),
4245 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4246 observed_channel_message_id: channels.observed_channel_messages.clone(),
4247 }
4248}
4249
4250fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4251 let mut update = proto::UpdateChannels::default();
4252
4253 for channel in channels.channels {
4254 update.channels.push(channel.to_proto());
4255 }
4256
4257 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4258 update.latest_channel_message_ids = channels.latest_channel_messages;
4259
4260 for (channel_id, participants) in channels.channel_participants {
4261 update
4262 .channel_participants
4263 .push(proto::ChannelParticipants {
4264 channel_id: channel_id.to_proto(),
4265 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4266 });
4267 }
4268
4269 for channel in channels.invited_channels {
4270 update.channel_invitations.push(channel.to_proto());
4271 }
4272
4273 update
4274}
4275
4276fn build_initial_contacts_update(
4277 contacts: Vec<db::Contact>,
4278 pool: &ConnectionPool,
4279) -> proto::UpdateContacts {
4280 let mut update = proto::UpdateContacts::default();
4281
4282 for contact in contacts {
4283 match contact {
4284 db::Contact::Accepted { user_id, busy } => {
4285 update.contacts.push(contact_for_user(user_id, busy, pool));
4286 }
4287 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4288 db::Contact::Incoming { user_id } => {
4289 update
4290 .incoming_requests
4291 .push(proto::IncomingContactRequest {
4292 requester_id: user_id.to_proto(),
4293 })
4294 }
4295 }
4296 }
4297
4298 update
4299}
4300
4301fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4302 proto::Contact {
4303 user_id: user_id.to_proto(),
4304 online: pool.is_user_online(user_id),
4305 busy,
4306 }
4307}
4308
4309fn room_updated(room: &proto::Room, peer: &Peer) {
4310 broadcast(
4311 None,
4312 room.participants
4313 .iter()
4314 .filter_map(|participant| Some(participant.peer_id?.into())),
4315 |peer_id| {
4316 peer.send(
4317 peer_id,
4318 proto::RoomUpdated {
4319 room: Some(room.clone()),
4320 },
4321 )
4322 },
4323 );
4324}
4325
4326fn channel_updated(
4327 channel: &db::channel::Model,
4328 room: &proto::Room,
4329 peer: &Peer,
4330 pool: &ConnectionPool,
4331) {
4332 let participants = room
4333 .participants
4334 .iter()
4335 .map(|p| p.user_id)
4336 .collect::<Vec<_>>();
4337
4338 broadcast(
4339 None,
4340 pool.channel_connection_ids(channel.root_id())
4341 .filter_map(|(channel_id, role)| {
4342 role.can_see_channel(channel.visibility)
4343 .then_some(channel_id)
4344 }),
4345 |peer_id| {
4346 peer.send(
4347 peer_id,
4348 proto::UpdateChannels {
4349 channel_participants: vec![proto::ChannelParticipants {
4350 channel_id: channel.id.to_proto(),
4351 participant_user_ids: participants.clone(),
4352 }],
4353 ..Default::default()
4354 },
4355 )
4356 },
4357 );
4358}
4359
4360async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4361 let db = session.db().await;
4362
4363 let contacts = db.get_contacts(user_id).await?;
4364 let busy = db.is_user_busy(user_id).await?;
4365
4366 let pool = session.connection_pool().await;
4367 let updated_contact = contact_for_user(user_id, busy, &pool);
4368 for contact in contacts {
4369 if let db::Contact::Accepted {
4370 user_id: contact_user_id,
4371 ..
4372 } = contact
4373 {
4374 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4375 session
4376 .peer
4377 .send(
4378 contact_conn_id,
4379 proto::UpdateContacts {
4380 contacts: vec![updated_contact.clone()],
4381 remove_contacts: Default::default(),
4382 incoming_requests: Default::default(),
4383 remove_incoming_requests: Default::default(),
4384 outgoing_requests: Default::default(),
4385 remove_outgoing_requests: Default::default(),
4386 },
4387 )
4388 .trace_err();
4389 }
4390 }
4391 }
4392 Ok(())
4393}
4394
4395async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4396 let mut contacts_to_update = HashSet::default();
4397
4398 let room_id;
4399 let canceled_calls_to_user_ids;
4400 let livekit_room;
4401 let delete_livekit_room;
4402 let room;
4403 let channel;
4404
4405 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4406 contacts_to_update.insert(session.user_id());
4407
4408 for project in left_room.left_projects.values() {
4409 project_left(project, session);
4410 }
4411
4412 room_id = RoomId::from_proto(left_room.room.id);
4413 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4414 livekit_room = mem::take(&mut left_room.room.livekit_room);
4415 delete_livekit_room = left_room.deleted;
4416 room = mem::take(&mut left_room.room);
4417 channel = mem::take(&mut left_room.channel);
4418
4419 room_updated(&room, &session.peer);
4420 } else {
4421 return Ok(());
4422 }
4423
4424 if let Some(channel) = channel {
4425 channel_updated(
4426 &channel,
4427 &room,
4428 &session.peer,
4429 &*session.connection_pool().await,
4430 );
4431 }
4432
4433 {
4434 let pool = session.connection_pool().await;
4435 for canceled_user_id in canceled_calls_to_user_ids {
4436 for connection_id in pool.user_connection_ids(canceled_user_id) {
4437 session
4438 .peer
4439 .send(
4440 connection_id,
4441 proto::CallCanceled {
4442 room_id: room_id.to_proto(),
4443 },
4444 )
4445 .trace_err();
4446 }
4447 contacts_to_update.insert(canceled_user_id);
4448 }
4449 }
4450
4451 for contact_user_id in contacts_to_update {
4452 update_user_contacts(contact_user_id, session).await?;
4453 }
4454
4455 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4456 live_kit
4457 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4458 .await
4459 .trace_err();
4460
4461 if delete_livekit_room {
4462 live_kit.delete_room(livekit_room).await.trace_err();
4463 }
4464 }
4465
4466 Ok(())
4467}
4468
4469async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4470 let left_channel_buffers = session
4471 .db()
4472 .await
4473 .leave_channel_buffers(session.connection_id)
4474 .await?;
4475
4476 for left_buffer in left_channel_buffers {
4477 channel_buffer_updated(
4478 session.connection_id,
4479 left_buffer.connections,
4480 &proto::UpdateChannelBufferCollaborators {
4481 channel_id: left_buffer.channel_id.to_proto(),
4482 collaborators: left_buffer.collaborators,
4483 },
4484 &session.peer,
4485 );
4486 }
4487
4488 Ok(())
4489}
4490
4491fn project_left(project: &db::LeftProject, session: &Session) {
4492 for connection_id in &project.connection_ids {
4493 if project.should_unshare {
4494 session
4495 .peer
4496 .send(
4497 *connection_id,
4498 proto::UnshareProject {
4499 project_id: project.id.to_proto(),
4500 },
4501 )
4502 .trace_err();
4503 } else {
4504 session
4505 .peer
4506 .send(
4507 *connection_id,
4508 proto::RemoveProjectCollaborator {
4509 project_id: project.id.to_proto(),
4510 peer_id: Some(session.connection_id.into()),
4511 },
4512 )
4513 .trace_err();
4514 }
4515 }
4516}
4517
4518pub trait ResultExt {
4519 type Ok;
4520
4521 fn trace_err(self) -> Option<Self::Ok>;
4522}
4523
4524impl<T, E> ResultExt for Result<T, E>
4525where
4526 E: std::fmt::Debug,
4527{
4528 type Ok = T;
4529
4530 #[track_caller]
4531 fn trace_err(self) -> Option<T> {
4532 match self {
4533 Ok(value) => Some(value),
4534 Err(error) => {
4535 tracing::error!("{:?}", error);
4536 None
4537 }
4538 }
4539 }
4540}