1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// NOTE(review): not referenced in this excerpt; presumably how long a disconnected
/// client is given to reconnect before its state is released — confirm at the call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are gone, their
/// old resources can be cleaned up safely, so this is set somewhat higher than that grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three constants below are not used in this excerpt. By their names they are
// page sizes / limits for channel-message and notification requests — confirm at the call sites
// (including whether MAX_MESSAGE_LEN counts bytes or characters).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives a boxed incoming envelope together with the `Session` of the connection that sent
/// it. It returns a boxed future that processes the message; errors are logged inside that future
/// rather than returned.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`; handlers see `user` as the acting identity.
    Impersonated { user: User, admin: User },
}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed to every message handler.
///
/// `handle_connection` builds one `Session` per connection and clones it for each incoming
/// message. Most shared pieces sit behind `Arc`, so those parts of a clone are cheap.
#[derive(Clone)]
struct Session {
    /// The user this connection acts as (and the impersonating admin, if any).
    principal: Principal,
    /// The id the `Peer` assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// Registry of live connections shared with the `Server`; lock it with `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Supermaven admin client; `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Optional system identifier supplied with the connection (presumably from the
    /// system-id request header — confirm at the caller).
    system_id: Option<String>,
    /// Executor this connection runs on; held but not read directly by `Session` code here.
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server: owns the peer, the connection registry, and the table of
/// message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `Server::reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer through which all connections send and receive messages.
    peer: Arc<Peer>,
    /// Registry of live client connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is tearing down; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside a `Send` future. In test builds its `Drop` impl also checks pool invariants.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref`, i.e. as the `ConnectionPool` itself rather than the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
332 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
333 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
338 .add_request_handler(
339 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
340 )
341 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
342 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
345 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
346 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
347 .add_message_handler(create_buffer_for_peer)
348 .add_request_handler(update_buffer)
349 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
354 .add_request_handler(get_users)
355 .add_request_handler(fuzzy_search_users)
356 .add_request_handler(request_contact)
357 .add_request_handler(remove_contact)
358 .add_request_handler(respond_to_contact_request)
359 .add_message_handler(subscribe_to_channels)
360 .add_request_handler(create_channel)
361 .add_request_handler(delete_channel)
362 .add_request_handler(invite_channel_member)
363 .add_request_handler(remove_channel_member)
364 .add_request_handler(set_channel_member_role)
365 .add_request_handler(set_channel_visibility)
366 .add_request_handler(rename_channel)
367 .add_request_handler(join_channel_buffer)
368 .add_request_handler(leave_channel_buffer)
369 .add_message_handler(update_channel_buffer)
370 .add_request_handler(rejoin_channel_buffers)
371 .add_request_handler(get_channel_members)
372 .add_request_handler(respond_to_channel_invite)
373 .add_request_handler(join_channel)
374 .add_request_handler(join_channel_chat)
375 .add_message_handler(leave_channel_chat)
376 .add_request_handler(send_channel_message)
377 .add_request_handler(remove_channel_message)
378 .add_request_handler(update_channel_message)
379 .add_request_handler(get_channel_messages)
380 .add_request_handler(get_channel_messages_by_id)
381 .add_request_handler(get_notifications)
382 .add_request_handler(mark_notification_as_read)
383 .add_request_handler(move_channel)
384 .add_request_handler(follow)
385 .add_message_handler(unfollow)
386 .add_message_handler(update_followers)
387 .add_request_handler(get_private_user_info)
388 .add_request_handler(get_llm_api_token)
389 .add_request_handler(accept_terms_of_service)
390 .add_message_handler(acknowledge_channel_message)
391 .add_message_handler(acknowledge_buffer_version)
392 .add_request_handler(get_supermaven_api_key)
393 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
395 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
396 .add_request_handler(forward_mutating_project_request::<proto::Push>)
397 .add_request_handler(forward_mutating_project_request::<proto::Pull>)
398 .add_request_handler(forward_mutating_project_request::<proto::Fetch>)
399 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
400 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
401 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
402 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
403 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
404 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
405 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
406 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
407 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
408 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
409 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
410 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
411 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
412 .add_message_handler(update_context)
413 .add_request_handler({
414 let app_state = app_state.clone();
415 move |request, response, session| {
416 let app_state = app_state.clone();
417 async move {
418 count_language_model_tokens(request, response, session, &app_state.config)
419 .await
420 }
421 }
422 })
423 .add_request_handler(get_cached_embeddings)
424 .add_request_handler({
425 let app_state = app_state.clone();
426 move |request, response, session| {
427 compute_embeddings(
428 request,
429 response,
430 session,
431 app_state.config.openai_api_key.clone(),
432 )
433 }
434 });
435
436 Arc::new(server)
437 }
438
439 pub async fn start(&self) -> Result<()> {
440 let server_id = *self.id.lock();
441 let app_state = self.app_state.clone();
442 let peer = self.peer.clone();
443 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
444 let pool = self.connection_pool.clone();
445 let livekit_client = self.app_state.livekit_client.clone();
446
447 let span = info_span!("start server");
448 self.app_state.executor.spawn_detached(
449 async move {
450 tracing::info!("waiting for cleanup timeout");
451 timeout.await;
452 tracing::info!("cleanup timeout expired, retrieving stale rooms");
453 if let Some((room_ids, channel_ids)) = app_state
454 .db
455 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
456 .await
457 .trace_err()
458 {
459 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
460 tracing::info!(
461 stale_channel_buffer_count = channel_ids.len(),
462 "retrieved stale channel buffers"
463 );
464
465 for channel_id in channel_ids {
466 if let Some(refreshed_channel_buffer) = app_state
467 .db
468 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
469 .await
470 .trace_err()
471 {
472 for connection_id in refreshed_channel_buffer.connection_ids {
473 peer.send(
474 connection_id,
475 proto::UpdateChannelBufferCollaborators {
476 channel_id: channel_id.to_proto(),
477 collaborators: refreshed_channel_buffer
478 .collaborators
479 .clone(),
480 },
481 )
482 .trace_err();
483 }
484 }
485 }
486
487 for room_id in room_ids {
488 let mut contacts_to_update = HashSet::default();
489 let mut canceled_calls_to_user_ids = Vec::new();
490 let mut livekit_room = String::new();
491 let mut delete_livekit_room = false;
492
493 if let Some(mut refreshed_room) = app_state
494 .db
495 .clear_stale_room_participants(room_id, server_id)
496 .await
497 .trace_err()
498 {
499 tracing::info!(
500 room_id = room_id.0,
501 new_participant_count = refreshed_room.room.participants.len(),
502 "refreshed room"
503 );
504 room_updated(&refreshed_room.room, &peer);
505 if let Some(channel) = refreshed_room.channel.as_ref() {
506 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
507 }
508 contacts_to_update
509 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
510 contacts_to_update
511 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
512 canceled_calls_to_user_ids =
513 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
514 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
515 delete_livekit_room = refreshed_room.room.participants.is_empty();
516 }
517
518 {
519 let pool = pool.lock();
520 for canceled_user_id in canceled_calls_to_user_ids {
521 for connection_id in pool.user_connection_ids(canceled_user_id) {
522 peer.send(
523 connection_id,
524 proto::CallCanceled {
525 room_id: room_id.to_proto(),
526 },
527 )
528 .trace_err();
529 }
530 }
531 }
532
533 for user_id in contacts_to_update {
534 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
535 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
536 if let Some((busy, contacts)) = busy.zip(contacts) {
537 let pool = pool.lock();
538 let updated_contact = contact_for_user(user_id, busy, &pool);
539 for contact in contacts {
540 if let db::Contact::Accepted {
541 user_id: contact_user_id,
542 ..
543 } = contact
544 {
545 for contact_conn_id in
546 pool.user_connection_ids(contact_user_id)
547 {
548 peer.send(
549 contact_conn_id,
550 proto::UpdateContacts {
551 contacts: vec![updated_contact.clone()],
552 remove_contacts: Default::default(),
553 incoming_requests: Default::default(),
554 remove_incoming_requests: Default::default(),
555 outgoing_requests: Default::default(),
556 remove_outgoing_requests: Default::default(),
557 },
558 )
559 .trace_err();
560 }
561 }
562 }
563 }
564 }
565
566 if let Some(live_kit) = livekit_client.as_ref() {
567 if delete_livekit_room {
568 live_kit.delete_room(livekit_room).await.trace_err();
569 }
570 }
571 }
572 }
573
574 app_state
575 .db
576 .delete_stale_servers(&app_state.config.zed_environment, server_id)
577 .await
578 .trace_err();
579 }
580 .instrument(span),
581 );
582 Ok(())
583 }
584
585 pub fn teardown(&self) {
586 self.peer.teardown();
587 self.connection_pool.lock().reset();
588 let _ = self.teardown.send(true);
589 }
590
591 #[cfg(test)]
592 pub fn reset(&self, id: ServerId) {
593 self.teardown();
594 *self.id.lock() = id;
595 self.peer.reset(id.0 as u32);
596 let _ = self.teardown.send(false);
597 }
598
599 #[cfg(test)]
600 pub fn id(&self) -> ServerId {
601 *self.id.lock()
602 }
603
604 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
605 where
606 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
607 Fut: 'static + Send + Future<Output = Result<()>>,
608 M: EnvelopedMessage,
609 {
610 let prev_handler = self.handlers.insert(
611 TypeId::of::<M>(),
612 Box::new(move |envelope, session| {
613 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
614 let received_at = envelope.received_at;
615 tracing::info!("message received");
616 let start_time = Instant::now();
617 let future = (handler)(*envelope, session);
618 async move {
619 let result = future.await;
620 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
621 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
622 let queue_duration_ms = total_duration_ms - processing_duration_ms;
623 let payload_type = M::NAME;
624
625 match result {
626 Err(error) => {
627 tracing::error!(
628 ?error,
629 total_duration_ms,
630 processing_duration_ms,
631 queue_duration_ms,
632 payload_type,
633 "error handling message"
634 )
635 }
636 Ok(()) => tracing::info!(
637 total_duration_ms,
638 processing_duration_ms,
639 queue_duration_ms,
640 "finished handling message"
641 ),
642 }
643 }
644 .boxed()
645 }),
646 );
647 if prev_handler.is_some() {
648 panic!("registered a handler for the same message twice");
649 }
650 self
651 }
652
653 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
654 where
655 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
656 Fut: 'static + Send + Future<Output = Result<()>>,
657 M: EnvelopedMessage,
658 {
659 self.add_handler(move |envelope, session| handler(envelope.payload, session));
660 self
661 }
662
663 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
664 where
665 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
666 Fut: Send + Future<Output = Result<()>>,
667 M: RequestMessage,
668 {
669 let handler = Arc::new(handler);
670 self.add_handler(move |envelope, session| {
671 let receipt = envelope.receipt();
672 let handler = handler.clone();
673 async move {
674 let peer = session.peer.clone();
675 let responded = Arc::new(AtomicBool::default());
676 let response = Response {
677 peer: peer.clone(),
678 responded: responded.clone(),
679 receipt,
680 };
681 match (handler)(envelope.payload, response, session).await {
682 Ok(()) => {
683 if responded.load(std::sync::atomic::Ordering::SeqCst) {
684 Ok(())
685 } else {
686 Err(anyhow!("handler did not send a response"))?
687 }
688 }
689 Err(error) => {
690 let proto_err = match &error {
691 Error::Internal(err) => err.to_proto(),
692 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
693 };
694 peer.respond_with_error(receipt, proto_err)?;
695 Err(error)
696 }
697 }
698 }
699 })
700 }
701
702 #[allow(clippy::too_many_arguments)]
703 pub fn handle_connection(
704 self: &Arc<Self>,
705 connection: Connection,
706 address: String,
707 principal: Principal,
708 zed_version: ZedVersion,
709 geoip_country_code: Option<String>,
710 system_id: Option<String>,
711 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
712 executor: Executor,
713 ) -> impl Future<Output = ()> {
714 let this = self.clone();
715 let span = info_span!("handle connection", %address,
716 connection_id=field::Empty,
717 user_id=field::Empty,
718 login=field::Empty,
719 impersonator=field::Empty,
720 geoip_country_code=field::Empty
721 );
722 principal.update_span(&span);
723 if let Some(country_code) = geoip_country_code.as_ref() {
724 span.record("geoip_country_code", country_code);
725 }
726
727 let mut teardown = self.teardown.subscribe();
728 async move {
729 if *teardown.borrow() {
730 tracing::error!("server is tearing down");
731 return
732 }
733 let (connection_id, handle_io, mut incoming_rx) = this
734 .peer
735 .add_connection(connection, {
736 let executor = executor.clone();
737 move |duration| executor.sleep(duration)
738 });
739 tracing::Span::current().record("connection_id", format!("{}", connection_id));
740
741 tracing::info!("connection opened");
742
743 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
744 let http_client = match ReqwestClient::user_agent(&user_agent) {
745 Ok(http_client) => Arc::new(http_client),
746 Err(error) => {
747 tracing::error!(?error, "failed to create HTTP client");
748 return;
749 }
750 };
751
752 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
753 supermaven_admin_api_key.to_string(),
754 http_client.clone(),
755 )));
756
757 let session = Session {
758 principal: principal.clone(),
759 connection_id,
760 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
761 peer: this.peer.clone(),
762 connection_pool: this.connection_pool.clone(),
763 app_state: this.app_state.clone(),
764 http_client,
765 geoip_country_code,
766 system_id,
767 _executor: executor.clone(),
768 supermaven_client,
769 };
770
771 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
772 tracing::error!(?error, "failed to send initial client update");
773 return;
774 }
775
776 let handle_io = handle_io.fuse();
777 futures::pin_mut!(handle_io);
778
779 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
780 // This prevents deadlocks when e.g., client A performs a request to client B and
781 // client B performs a request to client A. If both clients stop processing further
782 // messages until their respective request completes, they won't have a chance to
783 // respond to the other client's request and cause a deadlock.
784 //
785 // This arrangement ensures we will attempt to process earlier messages first, but fall
786 // back to processing messages arrived later in the spirit of making progress.
787 let mut foreground_message_handlers = FuturesUnordered::new();
788 let concurrent_handlers = Arc::new(Semaphore::new(256));
789 loop {
790 let next_message = async {
791 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
792 let message = incoming_rx.next().await;
793 (permit, message)
794 }.fuse();
795 futures::pin_mut!(next_message);
796 futures::select_biased! {
797 _ = teardown.changed().fuse() => return,
798 result = handle_io => {
799 if let Err(error) = result {
800 tracing::error!(?error, "error handling I/O");
801 }
802 break;
803 }
804 _ = foreground_message_handlers.next() => {}
805 next_message = next_message => {
806 let (permit, message) = next_message;
807 if let Some(message) = message {
808 let type_name = message.payload_type_name();
809 // note: we copy all the fields from the parent span so we can query them in the logs.
810 // (https://github.com/tokio-rs/tracing/issues/2670).
811 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
812 user_id=field::Empty,
813 login=field::Empty,
814 impersonator=field::Empty,
815 );
816 principal.update_span(&span);
817 let span_enter = span.enter();
818 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
819 let is_background = message.is_background();
820 let handle_message = (handler)(message, session.clone());
821 drop(span_enter);
822
823 let handle_message = async move {
824 handle_message.await;
825 drop(permit);
826 }.instrument(span);
827 if is_background {
828 executor.spawn_detached(handle_message);
829 } else {
830 foreground_message_handlers.push(handle_message);
831 }
832 } else {
833 tracing::error!("no message handler");
834 }
835 } else {
836 tracing::info!("connection closed");
837 break;
838 }
839 }
840 }
841 }
842
843 drop(foreground_message_handlers);
844 tracing::info!("signing out");
845 if let Err(error) = connection_lost(session, teardown, executor).await {
846 tracing::error!(?error, "error signing out");
847 }
848
849 }.instrument(span)
850 }
851
852 async fn send_initial_client_update(
853 &self,
854 connection_id: ConnectionId,
855 principal: &Principal,
856 zed_version: ZedVersion,
857 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
858 session: &Session,
859 ) -> Result<()> {
860 self.peer.send(
861 connection_id,
862 proto::Hello {
863 peer_id: Some(connection_id.into()),
864 },
865 )?;
866 tracing::info!("sent hello message");
867 if let Some(send_connection_id) = send_connection_id.take() {
868 let _ = send_connection_id.send(connection_id);
869 }
870
871 match principal {
872 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
873 if !user.connected_once {
874 self.peer.send(connection_id, proto::ShowContacts {})?;
875 self.app_state
876 .db
877 .set_user_connected_once(user.id, true)
878 .await?;
879 }
880
881 update_user_plan(user.id, session).await?;
882
883 let contacts = self.app_state.db.get_contacts(user.id).await?;
884
885 {
886 let mut pool = self.connection_pool.lock();
887 pool.add_connection(connection_id, user.id, user.admin, zed_version);
888 self.peer.send(
889 connection_id,
890 build_initial_contacts_update(contacts, &pool),
891 )?;
892 }
893
894 if should_auto_subscribe_to_channels(zed_version) {
895 subscribe_user_to_channels(user.id, session).await?;
896 }
897
898 if let Some(incoming_call) =
899 self.app_state.db.incoming_call_for_user(user.id).await?
900 {
901 self.peer.send(connection_id, incoming_call)?;
902 }
903
904 update_user_contacts(user.id, session).await?;
905 }
906 }
907
908 Ok(())
909 }
910
911 pub async fn invite_code_redeemed(
912 self: &Arc<Self>,
913 inviter_id: UserId,
914 invitee_id: UserId,
915 ) -> Result<()> {
916 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
917 if let Some(code) = &user.invite_code {
918 let pool = self.connection_pool.lock();
919 let invitee_contact = contact_for_user(invitee_id, false, &pool);
920 for connection_id in pool.user_connection_ids(inviter_id) {
921 self.peer.send(
922 connection_id,
923 proto::UpdateContacts {
924 contacts: vec![invitee_contact.clone()],
925 ..Default::default()
926 },
927 )?;
928 self.peer.send(
929 connection_id,
930 proto::UpdateInviteInfo {
931 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
932 count: user.invite_count as u32,
933 },
934 )?;
935 }
936 }
937 }
938 Ok(())
939 }
940
941 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
942 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
943 if let Some(invite_code) = &user.invite_code {
944 let pool = self.connection_pool.lock();
945 for connection_id in pool.user_connection_ids(user_id) {
946 self.peer.send(
947 connection_id,
948 proto::UpdateInviteInfo {
949 url: format!(
950 "{}{}",
951 self.app_state.config.invite_link_prefix, invite_code
952 ),
953 count: user.invite_count as u32,
954 },
955 )?;
956 }
957 }
958 }
959 Ok(())
960 }
961
962 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
963 let pool = self.connection_pool.lock();
964 for connection_id in pool.user_connection_ids(user_id) {
965 self.peer
966 .send(connection_id, proto::RefreshLlmToken {})
967 .trace_err();
968 }
969 }
970
971 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
972 ServerSnapshot {
973 connection_pool: ConnectionPoolGuard {
974 guard: self.connection_pool.lock(),
975 _not_send: PhantomData,
976 },
977 peer: &self.peer,
978 }
979 }
980}
981
982impl Deref for ConnectionPoolGuard<'_> {
983 type Target = ConnectionPool;
984
985 fn deref(&self) -> &Self::Target {
986 &self.guard
987 }
988}
989
990impl DerefMut for ConnectionPoolGuard<'_> {
991 fn deref_mut(&mut self) -> &mut Self::Target {
992 &mut self.guard
993 }
994}
995
/// Verifies the pool's invariants when the guard is released (test builds only).
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; in other builds dropping the guard does no extra work.
        #[cfg(test)]
        self.check_invariants();
    }
}
1002
1003fn broadcast<F>(
1004 sender_id: Option<ConnectionId>,
1005 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1006 mut f: F,
1007) where
1008 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1009{
1010 for receiver_id in receiver_ids {
1011 if Some(receiver_id) != sender_id {
1012 if let Err(error) = f(receiver_id) {
1013 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1014 }
1015 }
1016 }
1017}
1018
1019pub struct ProtocolVersion(u32);
1020
1021impl Header for ProtocolVersion {
1022 fn name() -> &'static HeaderName {
1023 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1025 }
1026
1027 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028 where
1029 Self: Sized,
1030 I: Iterator<Item = &'i axum::http::HeaderValue>,
1031 {
1032 let version = values
1033 .next()
1034 .ok_or_else(axum::headers::Error::invalid)?
1035 .to_str()
1036 .map_err(|_| axum::headers::Error::invalid())?
1037 .parse()
1038 .map_err(|_| axum::headers::Error::invalid())?;
1039 Ok(Self(version))
1040 }
1041
1042 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043 values.extend([self.0.to_string().parse().unwrap()]);
1044 }
1045}
1046
1047pub struct AppVersionHeader(SemanticVersion);
1048impl Header for AppVersionHeader {
1049 fn name() -> &'static HeaderName {
1050 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1051 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1052 }
1053
1054 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1055 where
1056 Self: Sized,
1057 I: Iterator<Item = &'i axum::http::HeaderValue>,
1058 {
1059 let version = values
1060 .next()
1061 .ok_or_else(axum::headers::Error::invalid)?
1062 .to_str()
1063 .map_err(|_| axum::headers::Error::invalid())?
1064 .parse()
1065 .map_err(|_| axum::headers::Error::invalid())?;
1066 Ok(Self(version))
1067 }
1068
1069 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1070 values.extend([self.0.to_string().parse().unwrap()]);
1071 }
1072}
1073
1074pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1075 Router::new()
1076 .route("/rpc", get(handle_websocket_request))
1077 .layer(
1078 ServiceBuilder::new()
1079 .layer(Extension(server.app_state.clone()))
1080 .layer(middleware::from_fn(auth::validate_header)),
1081 )
1082 .route("/metrics", get(handle_metrics))
1083 .layer(Extension(server))
1084}
1085
1086#[allow(clippy::too_many_arguments)]
1087pub async fn handle_websocket_request(
1088 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1089 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1090 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1091 Extension(server): Extension<Arc<Server>>,
1092 Extension(principal): Extension<Principal>,
1093 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1094 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1095 ws: WebSocketUpgrade,
1096) -> axum::response::Response {
1097 if protocol_version != rpc::PROTOCOL_VERSION {
1098 return (
1099 StatusCode::UPGRADE_REQUIRED,
1100 "client must be upgraded".to_string(),
1101 )
1102 .into_response();
1103 }
1104
1105 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1106 return (
1107 StatusCode::UPGRADE_REQUIRED,
1108 "no version header found".to_string(),
1109 )
1110 .into_response();
1111 };
1112
1113 if !version.can_collaborate() {
1114 return (
1115 StatusCode::UPGRADE_REQUIRED,
1116 "client must be upgraded".to_string(),
1117 )
1118 .into_response();
1119 }
1120
1121 let socket_address = socket_address.to_string();
1122 ws.on_upgrade(move |socket| {
1123 let socket = socket
1124 .map_ok(to_tungstenite_message)
1125 .err_into()
1126 .with(|message| async move { to_axum_message(message) });
1127 let connection = Connection::new(Box::pin(socket));
1128 async move {
1129 server
1130 .handle_connection(
1131 connection,
1132 socket_address,
1133 principal,
1134 version,
1135 country_code_header.map(|header| header.to_string()),
1136 system_id_header.map(|header| header.to_string()),
1137 None,
1138 Executor::Production,
1139 )
1140 .await;
1141 }
1142 })
1143}
1144
1145pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1146 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1147 let connections_metric = CONNECTIONS_METRIC
1148 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1149
1150 let connections = server
1151 .connection_pool
1152 .lock()
1153 .connections()
1154 .filter(|connection| !connection.admin)
1155 .count();
1156 connections_metric.set(connections as _);
1157
1158 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1159 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1160 register_int_gauge!(
1161 "shared_projects",
1162 "number of open projects with one or more guests"
1163 )
1164 .unwrap()
1165 });
1166
1167 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1168 shared_projects_metric.set(shared_projects as _);
1169
1170 let encoder = prometheus::TextEncoder::new();
1171 let metric_families = prometheus::gather();
1172 let encoded_metrics = encoder
1173 .encode_to_string(&metric_families)
1174 .map_err(|err| anyhow!("{}", err))?;
1175 Ok(encoded_metrics)
1176}
1177
/// Cleans up after a client connection drops.
///
/// Right away it does three things:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - reports the lost connection to the database. Errors from that call are only traced.
///
/// It then waits on two futures with `select_biased!`:
///
/// - If `RECONNECT_TIMEOUT` elapses first:
///   - the session leaves its room and its channel buffers (errors traced);
///   - if the user has no other online connection, `decline_call(None, ..)` is invoked
///     and any returned room is broadcast;
///   - the user's contacts are updated.
/// - If `teardown` changes first, none of that further cleanup runs.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` checks branches in the order written, so if both futures are
    // ready at the same time the timeout branch is taken.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when none of the user's other connections is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1224
1225/// Acknowledges a ping from a client, used to keep the connection alive.
1226async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1227 response.send(proto::Ack {})?;
1228 Ok(())
1229}
1230
1231/// Creates a new room for calling (outside of channels)
1232async fn create_room(
1233 _request: proto::CreateRoom,
1234 response: Response<proto::CreateRoom>,
1235 session: Session,
1236) -> Result<()> {
1237 let livekit_room = nanoid::nanoid!(30);
1238
1239 let live_kit_connection_info = util::maybe!(async {
1240 let live_kit = session.app_state.livekit_client.as_ref();
1241 let live_kit = live_kit?;
1242 let user_id = session.user_id().to_string();
1243
1244 let token = live_kit
1245 .room_token(&livekit_room, &user_id.to_string())
1246 .trace_err()?;
1247
1248 Some(proto::LiveKitConnectionInfo {
1249 server_url: live_kit.url().into(),
1250 token,
1251 can_publish: true,
1252 })
1253 })
1254 .await;
1255
1256 let room = session
1257 .db()
1258 .await
1259 .create_room(session.user_id(), session.connection_id, &livekit_room)
1260 .await?;
1261
1262 response.send(proto::CreateRoomResponse {
1263 room: Some(room.clone()),
1264 live_kit_connection_info,
1265 })?;
1266
1267 update_user_contacts(session.user_id(), &session).await?;
1268 Ok(())
1269}
1270
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, the request is handed to `join_channel_internal`.
/// Otherwise:
/// 1. The user is added to the room in the database and the updated room is broadcast.
/// 2. Every connection of the joining user is sent `CallCanceled` for this room. Send
///    errors are traced.
/// 3. If a LiveKit client is configured, a room token is requested. A token error is
///    traced and no LiveKit info is returned.
/// 4. The response is sent and the user's contacts are updated.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // The guard returned by the database is consumed with `into_inner` at the end of this
    // block, after the room update has been broadcast.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The connection-pool guard is a temporary in the loop header, held only for this loop.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1337
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database records the rejoin and returns the room plus two project lists:
/// `reshared_projects` and `rejoined_projects`. The client is answered with both lists,
/// then the updated room is broadcast.
///
/// For each reshared project:
/// - every listed collaborator is told that the peer id changed from the project's
///   `old_connection_id` to this session's connection;
/// - the project's worktrees are forwarded to every collaborator except this session.
///
/// Rejoined projects are handled by `notify_rejoined_projects`. After the guard has been
/// consumed with `into_inner`, the room's channel (if any) is refreshed via
/// `channel_updated`, and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared here and assigned inside the block so they outlive the database guard.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1429
1430fn notify_rejoined_projects(
1431 rejoined_projects: &mut Vec<RejoinedProject>,
1432 session: &Session,
1433) -> Result<()> {
1434 for project in rejoined_projects.iter() {
1435 for collaborator in &project.collaborators {
1436 session
1437 .peer
1438 .send(
1439 collaborator.connection_id,
1440 proto::UpdateProjectCollaborator {
1441 project_id: project.id.to_proto(),
1442 old_peer_id: Some(project.old_connection_id.into()),
1443 new_peer_id: Some(session.connection_id.into()),
1444 },
1445 )
1446 .trace_err();
1447 }
1448 }
1449
1450 for project in rejoined_projects {
1451 for worktree in mem::take(&mut project.worktrees) {
1452 // Stream this worktree's entries.
1453 let message = proto::UpdateWorktree {
1454 project_id: project.id.to_proto(),
1455 worktree_id: worktree.id,
1456 abs_path: worktree.abs_path.clone(),
1457 root_name: worktree.root_name,
1458 updated_entries: worktree.updated_entries,
1459 removed_entries: worktree.removed_entries,
1460 scan_id: worktree.scan_id,
1461 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1462 updated_repositories: worktree.updated_repositories,
1463 removed_repositories: worktree.removed_repositories,
1464 };
1465 for update in proto::split_worktree_update(message) {
1466 session.peer.send(session.connection_id, update.clone())?;
1467 }
1468
1469 // Stream this worktree's diagnostics.
1470 for summary in worktree.diagnostic_summaries {
1471 session.peer.send(
1472 session.connection_id,
1473 proto::UpdateDiagnosticSummary {
1474 project_id: project.id.to_proto(),
1475 worktree_id: worktree.id,
1476 summary: Some(summary),
1477 },
1478 )?;
1479 }
1480
1481 for settings_file in worktree.settings_files {
1482 session.peer.send(
1483 session.connection_id,
1484 proto::UpdateWorktreeSettings {
1485 project_id: project.id.to_proto(),
1486 worktree_id: worktree.id,
1487 path: settings_file.path,
1488 content: Some(settings_file.content),
1489 kind: Some(settings_file.kind.to_proto().into()),
1490 },
1491 )?;
1492 }
1493 }
1494
1495 for language_server in &project.language_servers {
1496 session.peer.send(
1497 session.connection_id,
1498 proto::UpdateLanguageServer {
1499 project_id: project.id.to_proto(),
1500 language_server_id: language_server.id,
1501 variant: Some(
1502 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1503 proto::LspDiskBasedDiagnosticsUpdated {},
1504 ),
1505 ),
1506 },
1507 )?;
1508 }
1509 }
1510 Ok(())
1511}
1512
1513/// leave room disconnects from the room.
1514async fn leave_room(
1515 _: proto::LeaveRoom,
1516 response: Response<proto::LeaveRoom>,
1517 session: Session,
1518) -> Result<()> {
1519 leave_room_for_session(&session, session.connection_id).await?;
1520 response.send(proto::Ack {})?;
1521 Ok(())
1522}
1523
1524/// Updates the permissions of someone else in the room.
1525async fn set_room_participant_role(
1526 request: proto::SetRoomParticipantRole,
1527 response: Response<proto::SetRoomParticipantRole>,
1528 session: Session,
1529) -> Result<()> {
1530 let user_id = UserId::from_proto(request.user_id);
1531 let role = ChannelRole::from(request.role());
1532
1533 let (livekit_room, can_publish) = {
1534 let room = session
1535 .db()
1536 .await
1537 .set_room_participant_role(
1538 session.user_id(),
1539 RoomId::from_proto(request.room_id),
1540 user_id,
1541 role,
1542 )
1543 .await?;
1544
1545 let livekit_room = room.livekit_room.clone();
1546 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1547 room_updated(&room, &session.peer);
1548 (livekit_room, can_publish)
1549 };
1550
1551 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1552 live_kit
1553 .update_participant(
1554 livekit_room.clone(),
1555 request.user_id.to_string(),
1556 livekit_api::proto::ParticipantPermission {
1557 can_subscribe: true,
1558 can_publish,
1559 can_publish_data: can_publish,
1560 hidden: false,
1561 recorder: false,
1562 },
1563 )
1564 .await
1565 .trace_err();
1566 }
1567
1568 response.send(proto::Ack {})?;
1569 Ok(())
1570}
1571
/// Call someone else into the current room.
///
/// Fails unless the called user is a contact of the caller. Otherwise:
/// 1. The call is recorded in the database and the updated room is broadcast.
/// 2. The called user's contacts are refreshed.
/// 3. The incoming call is sent as a request to all of the called user's connections
///    concurrently.
///
/// As soon as one connection answers successfully, an `Ack` is returned to the caller.
/// If none succeeds (including when the user has no connections), the call is marked
/// failed in the database, the room and contacts are updated again, and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Move the incoming call out of the guarded tuple; the guard itself is dropped at the
    // end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user at once; responses arrive in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // A failed ring is traced; keep waiting for the remaining connections.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1640
1641/// Cancel an outgoing call.
1642async fn cancel_call(
1643 request: proto::CancelCall,
1644 response: Response<proto::CancelCall>,
1645 session: Session,
1646) -> Result<()> {
1647 let called_user_id = UserId::from_proto(request.called_user_id);
1648 let room_id = RoomId::from_proto(request.room_id);
1649 {
1650 let room = session
1651 .db()
1652 .await
1653 .cancel_call(room_id, session.connection_id, called_user_id)
1654 .await?;
1655 room_updated(&room, &session.peer);
1656 }
1657
1658 for connection_id in session
1659 .connection_pool()
1660 .await
1661 .user_connection_ids(called_user_id)
1662 {
1663 session
1664 .peer
1665 .send(
1666 connection_id,
1667 proto::CallCanceled {
1668 room_id: room_id.to_proto(),
1669 },
1670 )
1671 .trace_err();
1672 }
1673 response.send(proto::Ack {})?;
1674
1675 update_user_contacts(called_user_id, &session).await?;
1676 Ok(())
1677}
1678
1679/// Decline an incoming call.
1680async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1681 let room_id = RoomId::from_proto(message.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .decline_call(Some(room_id), session.user_id())
1687 .await?
1688 .ok_or_else(|| anyhow!("failed to decline call"))?;
1689 room_updated(&room, &session.peer);
1690 }
1691
1692 for connection_id in session
1693 .connection_pool()
1694 .await
1695 .user_connection_ids(session.user_id())
1696 {
1697 session
1698 .peer
1699 .send(
1700 connection_id,
1701 proto::CallCanceled {
1702 room_id: room_id.to_proto(),
1703 },
1704 )
1705 .trace_err();
1706 }
1707 update_user_contacts(session.user_id(), &session).await?;
1708 Ok(())
1709}
1710
1711/// Updates other participants in the room with your current location.
1712async fn update_participant_location(
1713 request: proto::UpdateParticipantLocation,
1714 response: Response<proto::UpdateParticipantLocation>,
1715 session: Session,
1716) -> Result<()> {
1717 let room_id = RoomId::from_proto(request.room_id);
1718 let location = request
1719 .location
1720 .ok_or_else(|| anyhow!("invalid location"))?;
1721
1722 let db = session.db().await;
1723 let room = db
1724 .update_room_participant_location(room_id, session.connection_id, location)
1725 .await?;
1726
1727 room_updated(&room, &session.peer);
1728 response.send(proto::Ack {})?;
1729 Ok(())
1730}
1731
1732/// Share a project into the room.
1733async fn share_project(
1734 request: proto::ShareProject,
1735 response: Response<proto::ShareProject>,
1736 session: Session,
1737) -> Result<()> {
1738 let (project_id, room) = &*session
1739 .db()
1740 .await
1741 .share_project(
1742 RoomId::from_proto(request.room_id),
1743 session.connection_id,
1744 &request.worktrees,
1745 request.is_ssh_project,
1746 )
1747 .await?;
1748 response.send(proto::ShareProjectResponse {
1749 project_id: project_id.to_proto(),
1750 })?;
1751 room_updated(room, &session.peer);
1752
1753 Ok(())
1754}
1755
1756/// Unshare a project from the room.
1757async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1758 let project_id = ProjectId::from_proto(message.project_id);
1759 unshare_project_internal(project_id, session.connection_id, &session).await
1760}
1761
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The database returns three things:
/// - a flag saying whether the project should be deleted;
/// - the room, if any;
/// - the guest connection ids.
///
/// Every guest except `connection_id` is sent `UnshareProject`, and the room (if any) is
/// broadcast. Once the guard has been dropped at the end of the inner block, the project
/// is deleted from the database if the flag was set.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so `room_guard` can be released before deleting.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1799
/// Join someone else's shared project.
///
/// Joins the project through the database, then delegates to `join_project_internal`
/// to notify collaborators and stream the project's state to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming. The value returned by `join_project`
    // stays alive because `project` and `replica_id` borrow from it.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1818
/// Something that can deliver a `JoinProjectResponse`, so that `join_project_internal`
/// is not tied to a single concrete responder type.
trait JoinProjectInternalResponse {
    /// Sends `result` to the requester, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` RPC request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it is explicit that `Response`'s own `send` is being called.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1827
1828fn join_project_internal(
1829 response: impl JoinProjectInternalResponse,
1830 session: Session,
1831 project: &mut Project,
1832 replica_id: &ReplicaId,
1833) -> Result<()> {
1834 let collaborators = project
1835 .collaborators
1836 .iter()
1837 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1838 .map(|collaborator| collaborator.to_proto())
1839 .collect::<Vec<_>>();
1840 let project_id = project.id;
1841 let guest_user_id = session.user_id();
1842
1843 let worktrees = project
1844 .worktrees
1845 .iter()
1846 .map(|(id, worktree)| proto::WorktreeMetadata {
1847 id: *id,
1848 root_name: worktree.root_name.clone(),
1849 visible: worktree.visible,
1850 abs_path: worktree.abs_path.clone(),
1851 })
1852 .collect::<Vec<_>>();
1853
1854 let add_project_collaborator = proto::AddProjectCollaborator {
1855 project_id: project_id.to_proto(),
1856 collaborator: Some(proto::Collaborator {
1857 peer_id: Some(session.connection_id.into()),
1858 replica_id: replica_id.0 as u32,
1859 user_id: guest_user_id.to_proto(),
1860 is_host: false,
1861 }),
1862 };
1863
1864 for collaborator in &collaborators {
1865 session
1866 .peer
1867 .send(
1868 collaborator.peer_id.unwrap().into(),
1869 add_project_collaborator.clone(),
1870 )
1871 .trace_err();
1872 }
1873
1874 // First, we send the metadata associated with each worktree.
1875 response.send(proto::JoinProjectResponse {
1876 project_id: project.id.0 as u64,
1877 worktrees: worktrees.clone(),
1878 replica_id: replica_id.0 as u32,
1879 collaborators: collaborators.clone(),
1880 language_servers: project.language_servers.clone(),
1881 role: project.role.into(),
1882 })?;
1883
1884 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885 // Stream this worktree's entries.
1886 let message = proto::UpdateWorktree {
1887 project_id: project_id.to_proto(),
1888 worktree_id,
1889 abs_path: worktree.abs_path.clone(),
1890 root_name: worktree.root_name,
1891 updated_entries: worktree.entries,
1892 removed_entries: Default::default(),
1893 scan_id: worktree.scan_id,
1894 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1895 updated_repositories: worktree.repository_entries.into_values().collect(),
1896 removed_repositories: Default::default(),
1897 };
1898 for update in proto::split_worktree_update(message) {
1899 session.peer.send(session.connection_id, update.clone())?;
1900 }
1901
1902 // Stream this worktree's diagnostics.
1903 for summary in worktree.diagnostic_summaries {
1904 session.peer.send(
1905 session.connection_id,
1906 proto::UpdateDiagnosticSummary {
1907 project_id: project_id.to_proto(),
1908 worktree_id: worktree.id,
1909 summary: Some(summary),
1910 },
1911 )?;
1912 }
1913
1914 for settings_file in worktree.settings_files {
1915 session.peer.send(
1916 session.connection_id,
1917 proto::UpdateWorktreeSettings {
1918 project_id: project_id.to_proto(),
1919 worktree_id: worktree.id,
1920 path: settings_file.path,
1921 content: Some(settings_file.content),
1922 kind: Some(settings_file.kind.to_proto() as i32),
1923 },
1924 )?;
1925 }
1926 }
1927
1928 for language_server in &project.language_servers {
1929 session.peer.send(
1930 session.connection_id,
1931 proto::UpdateLanguageServer {
1932 project_id: project_id.to_proto(),
1933 language_server_id: language_server.id,
1934 variant: Some(
1935 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1936 proto::LspDiskBasedDiagnosticsUpdated {},
1937 ),
1938 ),
1939 },
1940 )?;
1941 }
1942
1943 Ok(())
1944}
1945
1946/// Leave someone elses shared project.
1947async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1948 let sender_id = session.connection_id;
1949 let project_id = ProjectId::from_proto(request.project_id);
1950 let db = session.db().await;
1951
1952 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1953 tracing::info!(
1954 %project_id,
1955 "leave project"
1956 );
1957
1958 project_left(project, &session);
1959 if let Some(room) = room {
1960 room_updated(room, &session.peer);
1961 }
1962
1963 Ok(())
1964}
1965
1966/// Updates other participants with changes to the project
1967async fn update_project(
1968 request: proto::UpdateProject,
1969 response: Response<proto::UpdateProject>,
1970 session: Session,
1971) -> Result<()> {
1972 let project_id = ProjectId::from_proto(request.project_id);
1973 let (room, guest_connection_ids) = &*session
1974 .db()
1975 .await
1976 .update_project(project_id, session.connection_id, &request.worktrees)
1977 .await?;
1978 broadcast(
1979 Some(session.connection_id),
1980 guest_connection_ids.iter().copied(),
1981 |connection_id| {
1982 session
1983 .peer
1984 .forward_send(session.connection_id, connection_id, request.clone())
1985 },
1986 );
1987 if let Some(room) = room {
1988 room_updated(room, &session.peer);
1989 }
1990 response.send(proto::Ack {})?;
1991
1992 Ok(())
1993}
1994
1995/// Updates other participants with changes to the worktree
1996async fn update_worktree(
1997 request: proto::UpdateWorktree,
1998 response: Response<proto::UpdateWorktree>,
1999 session: Session,
2000) -> Result<()> {
2001 let guest_connection_ids = session
2002 .db()
2003 .await
2004 .update_worktree(&request, session.connection_id)
2005 .await?;
2006
2007 broadcast(
2008 Some(session.connection_id),
2009 guest_connection_ids.iter().copied(),
2010 |connection_id| {
2011 session
2012 .peer
2013 .forward_send(session.connection_id, connection_id, request.clone())
2014 },
2015 );
2016 response.send(proto::Ack {})?;
2017 Ok(())
2018}
2019
2020/// Updates other participants with changes to the diagnostics
2021async fn update_diagnostic_summary(
2022 message: proto::UpdateDiagnosticSummary,
2023 session: Session,
2024) -> Result<()> {
2025 let guest_connection_ids = session
2026 .db()
2027 .await
2028 .update_diagnostic_summary(&message, session.connection_id)
2029 .await?;
2030
2031 broadcast(
2032 Some(session.connection_id),
2033 guest_connection_ids.iter().copied(),
2034 |connection_id| {
2035 session
2036 .peer
2037 .forward_send(session.connection_id, connection_id, message.clone())
2038 },
2039 );
2040
2041 Ok(())
2042}
2043
2044/// Updates other participants with changes to the worktree settings
2045async fn update_worktree_settings(
2046 message: proto::UpdateWorktreeSettings,
2047 session: Session,
2048) -> Result<()> {
2049 let guest_connection_ids = session
2050 .db()
2051 .await
2052 .update_worktree_settings(&message, session.connection_id)
2053 .await?;
2054
2055 broadcast(
2056 Some(session.connection_id),
2057 guest_connection_ids.iter().copied(),
2058 |connection_id| {
2059 session
2060 .peer
2061 .forward_send(session.connection_id, connection_id, message.clone())
2062 },
2063 );
2064
2065 Ok(())
2066}
2067
2068/// Notify other participants that a language server has started.
2069async fn start_language_server(
2070 request: proto::StartLanguageServer,
2071 session: Session,
2072) -> Result<()> {
2073 let guest_connection_ids = session
2074 .db()
2075 .await
2076 .start_language_server(&request, session.connection_id)
2077 .await?;
2078
2079 broadcast(
2080 Some(session.connection_id),
2081 guest_connection_ids.iter().copied(),
2082 |connection_id| {
2083 session
2084 .peer
2085 .forward_send(session.connection_id, connection_id, request.clone())
2086 },
2087 );
2088 Ok(())
2089}
2090
2091/// Notify other participants that a language server has changed.
2092async fn update_language_server(
2093 request: proto::UpdateLanguageServer,
2094 session: Session,
2095) -> Result<()> {
2096 let project_id = ProjectId::from_proto(request.project_id);
2097 let project_connection_ids = session
2098 .db()
2099 .await
2100 .project_connection_ids(project_id, session.connection_id, true)
2101 .await?;
2102 broadcast(
2103 Some(session.connection_id),
2104 project_connection_ids.iter().copied(),
2105 |connection_id| {
2106 session
2107 .peer
2108 .forward_send(session.connection_id, connection_id, request.clone())
2109 },
2110 );
2111 Ok(())
2112}
2113
2114/// forward a project request to the host. These requests should be read only
2115/// as guests are allowed to send them.
2116async fn forward_read_only_project_request<T>(
2117 request: T,
2118 response: Response<T>,
2119 session: Session,
2120) -> Result<()>
2121where
2122 T: EntityMessage + RequestMessage,
2123{
2124 let project_id = ProjectId::from_proto(request.remote_entity_id());
2125 let host_connection_id = session
2126 .db()
2127 .await
2128 .host_for_read_only_project_request(project_id, session.connection_id)
2129 .await?;
2130 let payload = session
2131 .peer
2132 .forward_request(session.connection_id, host_connection_id, request)
2133 .await?;
2134 response.send(payload)?;
2135 Ok(())
2136}
2137
2138async fn forward_find_search_candidates_request(
2139 request: proto::FindSearchCandidates,
2140 response: Response<proto::FindSearchCandidates>,
2141 session: Session,
2142) -> Result<()> {
2143 let project_id = ProjectId::from_proto(request.remote_entity_id());
2144 let host_connection_id = session
2145 .db()
2146 .await
2147 .host_for_read_only_project_request(project_id, session.connection_id)
2148 .await?;
2149 let payload = session
2150 .peer
2151 .forward_request(session.connection_id, host_connection_id, request)
2152 .await?;
2153 response.send(payload)?;
2154 Ok(())
2155}
2156
2157/// forward a project request to the host. These requests are disallowed
2158/// for guests.
2159async fn forward_mutating_project_request<T>(
2160 request: T,
2161 response: Response<T>,
2162 session: Session,
2163) -> Result<()>
2164where
2165 T: EntityMessage + RequestMessage,
2166{
2167 let project_id = ProjectId::from_proto(request.remote_entity_id());
2168
2169 let host_connection_id = session
2170 .db()
2171 .await
2172 .host_for_mutating_project_request(project_id, session.connection_id)
2173 .await?;
2174 let payload = session
2175 .peer
2176 .forward_request(session.connection_id, host_connection_id, request)
2177 .await?;
2178 response.send(payload)?;
2179 Ok(())
2180}
2181
2182/// Notify other participants that a new buffer has been created
2183async fn create_buffer_for_peer(
2184 request: proto::CreateBufferForPeer,
2185 session: Session,
2186) -> Result<()> {
2187 session
2188 .db()
2189 .await
2190 .check_user_is_project_host(
2191 ProjectId::from_proto(request.project_id),
2192 session.connection_id,
2193 )
2194 .await?;
2195 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2196 session
2197 .peer
2198 .forward_send(session.connection_id, peer_id.into(), request)?;
2199 Ok(())
2200}
2201
2202/// Notify other participants that a buffer has been updated. This is
2203/// allowed for guests as long as the update is limited to selections.
2204async fn update_buffer(
2205 request: proto::UpdateBuffer,
2206 response: Response<proto::UpdateBuffer>,
2207 session: Session,
2208) -> Result<()> {
2209 let project_id = ProjectId::from_proto(request.project_id);
2210 let mut capability = Capability::ReadOnly;
2211
2212 for op in request.operations.iter() {
2213 match op.variant {
2214 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2215 Some(_) => capability = Capability::ReadWrite,
2216 }
2217 }
2218
2219 let host = {
2220 let guard = session
2221 .db()
2222 .await
2223 .connections_for_buffer_update(project_id, session.connection_id, capability)
2224 .await?;
2225
2226 let (host, guests) = &*guard;
2227
2228 broadcast(
2229 Some(session.connection_id),
2230 guests.clone(),
2231 |connection_id| {
2232 session
2233 .peer
2234 .forward_send(session.connection_id, connection_id, request.clone())
2235 },
2236 );
2237
2238 *host
2239 };
2240
2241 if host != session.connection_id {
2242 session
2243 .peer
2244 .forward_request(session.connection_id, host, request.clone())
2245 .await?;
2246 }
2247
2248 response.send(proto::Ack {})?;
2249 Ok(())
2250}
2251
2252async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2253 let project_id = ProjectId::from_proto(message.project_id);
2254
2255 let operation = message.operation.as_ref().context("invalid operation")?;
2256 let capability = match operation.variant.as_ref() {
2257 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2258 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2259 match buffer_op.variant {
2260 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2261 Capability::ReadOnly
2262 }
2263 _ => Capability::ReadWrite,
2264 }
2265 } else {
2266 Capability::ReadWrite
2267 }
2268 }
2269 Some(_) => Capability::ReadWrite,
2270 None => Capability::ReadOnly,
2271 };
2272
2273 let guard = session
2274 .db()
2275 .await
2276 .connections_for_buffer_update(project_id, session.connection_id, capability)
2277 .await?;
2278
2279 let (host, guests) = &*guard;
2280
2281 broadcast(
2282 Some(session.connection_id),
2283 guests.iter().chain([host]).copied(),
2284 |connection_id| {
2285 session
2286 .peer
2287 .forward_send(session.connection_id, connection_id, message.clone())
2288 },
2289 );
2290
2291 Ok(())
2292}
2293
2294/// Notify other participants that a project has been updated.
2295async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2296 request: T,
2297 session: Session,
2298) -> Result<()> {
2299 let project_id = ProjectId::from_proto(request.remote_entity_id());
2300 let project_connection_ids = session
2301 .db()
2302 .await
2303 .project_connection_ids(project_id, session.connection_id, false)
2304 .await?;
2305
2306 broadcast(
2307 Some(session.connection_id),
2308 project_connection_ids.iter().copied(),
2309 |connection_id| {
2310 session
2311 .peer
2312 .forward_send(session.connection_id, connection_id, request.clone())
2313 },
2314 );
2315 Ok(())
2316}
2317
2318/// Start following another user in a call.
2319async fn follow(
2320 request: proto::Follow,
2321 response: Response<proto::Follow>,
2322 session: Session,
2323) -> Result<()> {
2324 let room_id = RoomId::from_proto(request.room_id);
2325 let project_id = request.project_id.map(ProjectId::from_proto);
2326 let leader_id = request
2327 .leader_id
2328 .ok_or_else(|| anyhow!("invalid leader id"))?
2329 .into();
2330 let follower_id = session.connection_id;
2331
2332 session
2333 .db()
2334 .await
2335 .check_room_participants(room_id, leader_id, session.connection_id)
2336 .await?;
2337
2338 let response_payload = session
2339 .peer
2340 .forward_request(session.connection_id, leader_id, request)
2341 .await?;
2342 response.send(response_payload)?;
2343
2344 if let Some(project_id) = project_id {
2345 let room = session
2346 .db()
2347 .await
2348 .follow(room_id, project_id, leader_id, follower_id)
2349 .await?;
2350 room_updated(&room, &session.peer);
2351 }
2352
2353 Ok(())
2354}
2355
2356/// Stop following another user in a call.
2357async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2358 let room_id = RoomId::from_proto(request.room_id);
2359 let project_id = request.project_id.map(ProjectId::from_proto);
2360 let leader_id = request
2361 .leader_id
2362 .ok_or_else(|| anyhow!("invalid leader id"))?
2363 .into();
2364 let follower_id = session.connection_id;
2365
2366 session
2367 .db()
2368 .await
2369 .check_room_participants(room_id, leader_id, session.connection_id)
2370 .await?;
2371
2372 session
2373 .peer
2374 .forward_send(session.connection_id, leader_id, request)?;
2375
2376 if let Some(project_id) = project_id {
2377 let room = session
2378 .db()
2379 .await
2380 .unfollow(room_id, project_id, leader_id, follower_id)
2381 .await?;
2382 room_updated(&room, &session.peer);
2383 }
2384
2385 Ok(())
2386}
2387
2388/// Notify everyone following you of your current location.
2389async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2390 let room_id = RoomId::from_proto(request.room_id);
2391 let database = session.db.lock().await;
2392
2393 let connection_ids = if let Some(project_id) = request.project_id {
2394 let project_id = ProjectId::from_proto(project_id);
2395 database
2396 .project_connection_ids(project_id, session.connection_id, true)
2397 .await?
2398 } else {
2399 database
2400 .room_connection_ids(room_id, session.connection_id)
2401 .await?
2402 };
2403
2404 // For now, don't send view update messages back to that view's current leader.
2405 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2406 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2407 _ => None,
2408 });
2409
2410 for connection_id in connection_ids.iter().cloned() {
2411 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2412 session
2413 .peer
2414 .forward_send(session.connection_id, connection_id, request.clone())?;
2415 }
2416 }
2417 Ok(())
2418}
2419
2420/// Get public data about users.
2421async fn get_users(
2422 request: proto::GetUsers,
2423 response: Response<proto::GetUsers>,
2424 session: Session,
2425) -> Result<()> {
2426 let user_ids = request
2427 .user_ids
2428 .into_iter()
2429 .map(UserId::from_proto)
2430 .collect();
2431 let users = session
2432 .db()
2433 .await
2434 .get_users_by_ids(user_ids)
2435 .await?
2436 .into_iter()
2437 .map(|user| proto::User {
2438 id: user.id.to_proto(),
2439 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2440 github_login: user.github_login,
2441 email: user.email_address,
2442 name: user.name,
2443 })
2444 .collect();
2445 response.send(proto::UsersResponse { users })?;
2446 Ok(())
2447}
2448
2449/// Search for users (to invite) buy Github login
2450async fn fuzzy_search_users(
2451 request: proto::FuzzySearchUsers,
2452 response: Response<proto::FuzzySearchUsers>,
2453 session: Session,
2454) -> Result<()> {
2455 let query = request.query;
2456 let users = match query.len() {
2457 0 => vec![],
2458 1 | 2 => session
2459 .db()
2460 .await
2461 .get_user_by_github_login(&query)
2462 .await?
2463 .into_iter()
2464 .collect(),
2465 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2466 };
2467 let users = users
2468 .into_iter()
2469 .filter(|user| user.id != session.user_id())
2470 .map(|user| proto::User {
2471 id: user.id.to_proto(),
2472 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2473 github_login: user.github_login,
2474 name: user.name,
2475 email: user.email_address,
2476 })
2477 .collect();
2478 response.send(proto::UsersResponse { users })?;
2479 Ok(())
2480}
2481
2482/// Send a contact request to another user.
2483async fn request_contact(
2484 request: proto::RequestContact,
2485 response: Response<proto::RequestContact>,
2486 session: Session,
2487) -> Result<()> {
2488 let requester_id = session.user_id();
2489 let responder_id = UserId::from_proto(request.responder_id);
2490 if requester_id == responder_id {
2491 return Err(anyhow!("cannot add yourself as a contact"))?;
2492 }
2493
2494 let notifications = session
2495 .db()
2496 .await
2497 .send_contact_request(requester_id, responder_id)
2498 .await?;
2499
2500 // Update outgoing contact requests of requester
2501 let mut update = proto::UpdateContacts::default();
2502 update.outgoing_requests.push(responder_id.to_proto());
2503 for connection_id in session
2504 .connection_pool()
2505 .await
2506 .user_connection_ids(requester_id)
2507 {
2508 session.peer.send(connection_id, update.clone())?;
2509 }
2510
2511 // Update incoming contact requests of responder
2512 let mut update = proto::UpdateContacts::default();
2513 update
2514 .incoming_requests
2515 .push(proto::IncomingContactRequest {
2516 requester_id: requester_id.to_proto(),
2517 });
2518 let connection_pool = session.connection_pool().await;
2519 for connection_id in connection_pool.user_connection_ids(responder_id) {
2520 session.peer.send(connection_id, update.clone())?;
2521 }
2522
2523 send_notifications(&connection_pool, &session.peer, notifications);
2524
2525 response.send(proto::Ack {})?;
2526 Ok(())
2527}
2528
2529/// Accept or decline a contact request
2530async fn respond_to_contact_request(
2531 request: proto::RespondToContactRequest,
2532 response: Response<proto::RespondToContactRequest>,
2533 session: Session,
2534) -> Result<()> {
2535 let responder_id = session.user_id();
2536 let requester_id = UserId::from_proto(request.requester_id);
2537 let db = session.db().await;
2538 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2539 db.dismiss_contact_notification(responder_id, requester_id)
2540 .await?;
2541 } else {
2542 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2543
2544 let notifications = db
2545 .respond_to_contact_request(responder_id, requester_id, accept)
2546 .await?;
2547 let requester_busy = db.is_user_busy(requester_id).await?;
2548 let responder_busy = db.is_user_busy(responder_id).await?;
2549
2550 let pool = session.connection_pool().await;
2551 // Update responder with new contact
2552 let mut update = proto::UpdateContacts::default();
2553 if accept {
2554 update
2555 .contacts
2556 .push(contact_for_user(requester_id, requester_busy, &pool));
2557 }
2558 update
2559 .remove_incoming_requests
2560 .push(requester_id.to_proto());
2561 for connection_id in pool.user_connection_ids(responder_id) {
2562 session.peer.send(connection_id, update.clone())?;
2563 }
2564
2565 // Update requester with new contact
2566 let mut update = proto::UpdateContacts::default();
2567 if accept {
2568 update
2569 .contacts
2570 .push(contact_for_user(responder_id, responder_busy, &pool));
2571 }
2572 update
2573 .remove_outgoing_requests
2574 .push(responder_id.to_proto());
2575
2576 for connection_id in pool.user_connection_ids(requester_id) {
2577 session.peer.send(connection_id, update.clone())?;
2578 }
2579
2580 send_notifications(&pool, &session.peer, notifications);
2581 }
2582
2583 response.send(proto::Ack {})?;
2584 Ok(())
2585}
2586
2587/// Remove a contact.
2588async fn remove_contact(
2589 request: proto::RemoveContact,
2590 response: Response<proto::RemoveContact>,
2591 session: Session,
2592) -> Result<()> {
2593 let requester_id = session.user_id();
2594 let responder_id = UserId::from_proto(request.user_id);
2595 let db = session.db().await;
2596 let (contact_accepted, deleted_notification_id) =
2597 db.remove_contact(requester_id, responder_id).await?;
2598
2599 let pool = session.connection_pool().await;
2600 // Update outgoing contact requests of requester
2601 let mut update = proto::UpdateContacts::default();
2602 if contact_accepted {
2603 update.remove_contacts.push(responder_id.to_proto());
2604 } else {
2605 update
2606 .remove_outgoing_requests
2607 .push(responder_id.to_proto());
2608 }
2609 for connection_id in pool.user_connection_ids(requester_id) {
2610 session.peer.send(connection_id, update.clone())?;
2611 }
2612
2613 // Update incoming contact requests of responder
2614 let mut update = proto::UpdateContacts::default();
2615 if contact_accepted {
2616 update.remove_contacts.push(requester_id.to_proto());
2617 } else {
2618 update
2619 .remove_incoming_requests
2620 .push(requester_id.to_proto());
2621 }
2622 for connection_id in pool.user_connection_ids(responder_id) {
2623 session.peer.send(connection_id, update.clone())?;
2624 if let Some(notification_id) = deleted_notification_id {
2625 session.peer.send(
2626 connection_id,
2627 proto::DeleteNotification {
2628 notification_id: notification_id.to_proto(),
2629 },
2630 )?;
2631 }
2632 }
2633
2634 response.send(proto::Ack {})?;
2635 Ok(())
2636}
2637
2638fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2639 version.0.minor() < 139
2640}
2641
2642async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2643 let plan = session.current_plan(&session.db().await).await?;
2644
2645 session
2646 .peer
2647 .send(
2648 session.connection_id,
2649 proto::UpdateUserPlan { plan: plan.into() },
2650 )
2651 .trace_err();
2652
2653 Ok(())
2654}
2655
2656async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2657 subscribe_user_to_channels(session.user_id(), &session).await?;
2658 Ok(())
2659}
2660
2661async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2662 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2663 let mut pool = session.connection_pool().await;
2664 for membership in &channels_for_user.channel_memberships {
2665 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2666 }
2667 session.peer.send(
2668 session.connection_id,
2669 build_update_user_channels(&channels_for_user),
2670 )?;
2671 session.peer.send(
2672 session.connection_id,
2673 build_channels_update(channels_for_user),
2674 )?;
2675 Ok(())
2676}
2677
2678/// Creates a new channel.
2679async fn create_channel(
2680 request: proto::CreateChannel,
2681 response: Response<proto::CreateChannel>,
2682 session: Session,
2683) -> Result<()> {
2684 let db = session.db().await;
2685
2686 let parent_id = request.parent_id.map(ChannelId::from_proto);
2687 let (channel, membership) = db
2688 .create_channel(&request.name, parent_id, session.user_id())
2689 .await?;
2690
2691 let root_id = channel.root_id();
2692 let channel = Channel::from_model(channel);
2693
2694 response.send(proto::CreateChannelResponse {
2695 channel: Some(channel.to_proto()),
2696 parent_id: request.parent_id,
2697 })?;
2698
2699 let mut connection_pool = session.connection_pool().await;
2700 if let Some(membership) = membership {
2701 connection_pool.subscribe_to_channel(
2702 membership.user_id,
2703 membership.channel_id,
2704 membership.role,
2705 );
2706 let update = proto::UpdateUserChannels {
2707 channel_memberships: vec![proto::ChannelMembership {
2708 channel_id: membership.channel_id.to_proto(),
2709 role: membership.role.into(),
2710 }],
2711 ..Default::default()
2712 };
2713 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2714 session.peer.send(connection_id, update.clone())?;
2715 }
2716 }
2717
2718 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2719 if !role.can_see_channel(channel.visibility) {
2720 continue;
2721 }
2722
2723 let update = proto::UpdateChannels {
2724 channels: vec![channel.to_proto()],
2725 ..Default::default()
2726 };
2727 session.peer.send(connection_id, update.clone())?;
2728 }
2729
2730 Ok(())
2731}
2732
2733/// Delete a channel
2734async fn delete_channel(
2735 request: proto::DeleteChannel,
2736 response: Response<proto::DeleteChannel>,
2737 session: Session,
2738) -> Result<()> {
2739 let db = session.db().await;
2740
2741 let channel_id = request.channel_id;
2742 let (root_channel, removed_channels) = db
2743 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2744 .await?;
2745 response.send(proto::Ack {})?;
2746
2747 // Notify members of removed channels
2748 let mut update = proto::UpdateChannels::default();
2749 update
2750 .delete_channels
2751 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2752
2753 let connection_pool = session.connection_pool().await;
2754 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2755 session.peer.send(connection_id, update.clone())?;
2756 }
2757
2758 Ok(())
2759}
2760
2761/// Invite someone to join a channel.
2762async fn invite_channel_member(
2763 request: proto::InviteChannelMember,
2764 response: Response<proto::InviteChannelMember>,
2765 session: Session,
2766) -> Result<()> {
2767 let db = session.db().await;
2768 let channel_id = ChannelId::from_proto(request.channel_id);
2769 let invitee_id = UserId::from_proto(request.user_id);
2770 let InviteMemberResult {
2771 channel,
2772 notifications,
2773 } = db
2774 .invite_channel_member(
2775 channel_id,
2776 invitee_id,
2777 session.user_id(),
2778 request.role().into(),
2779 )
2780 .await?;
2781
2782 let update = proto::UpdateChannels {
2783 channel_invitations: vec![channel.to_proto()],
2784 ..Default::default()
2785 };
2786
2787 let connection_pool = session.connection_pool().await;
2788 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2789 session.peer.send(connection_id, update.clone())?;
2790 }
2791
2792 send_notifications(&connection_pool, &session.peer, notifications);
2793
2794 response.send(proto::Ack {})?;
2795 Ok(())
2796}
2797
2798/// remove someone from a channel
2799async fn remove_channel_member(
2800 request: proto::RemoveChannelMember,
2801 response: Response<proto::RemoveChannelMember>,
2802 session: Session,
2803) -> Result<()> {
2804 let db = session.db().await;
2805 let channel_id = ChannelId::from_proto(request.channel_id);
2806 let member_id = UserId::from_proto(request.user_id);
2807
2808 let RemoveChannelMemberResult {
2809 membership_update,
2810 notification_id,
2811 } = db
2812 .remove_channel_member(channel_id, member_id, session.user_id())
2813 .await?;
2814
2815 let mut connection_pool = session.connection_pool().await;
2816 notify_membership_updated(
2817 &mut connection_pool,
2818 membership_update,
2819 member_id,
2820 &session.peer,
2821 );
2822 for connection_id in connection_pool.user_connection_ids(member_id) {
2823 if let Some(notification_id) = notification_id {
2824 session
2825 .peer
2826 .send(
2827 connection_id,
2828 proto::DeleteNotification {
2829 notification_id: notification_id.to_proto(),
2830 },
2831 )
2832 .trace_err();
2833 }
2834 }
2835
2836 response.send(proto::Ack {})?;
2837 Ok(())
2838}
2839
2840/// Toggle the channel between public and private.
2841/// Care is taken to maintain the invariant that public channels only descend from public channels,
2842/// (though members-only channels can appear at any point in the hierarchy).
2843async fn set_channel_visibility(
2844 request: proto::SetChannelVisibility,
2845 response: Response<proto::SetChannelVisibility>,
2846 session: Session,
2847) -> Result<()> {
2848 let db = session.db().await;
2849 let channel_id = ChannelId::from_proto(request.channel_id);
2850 let visibility = request.visibility().into();
2851
2852 let channel_model = db
2853 .set_channel_visibility(channel_id, visibility, session.user_id())
2854 .await?;
2855 let root_id = channel_model.root_id();
2856 let channel = Channel::from_model(channel_model);
2857
2858 let mut connection_pool = session.connection_pool().await;
2859 for (user_id, role) in connection_pool
2860 .channel_user_ids(root_id)
2861 .collect::<Vec<_>>()
2862 .into_iter()
2863 {
2864 let update = if role.can_see_channel(channel.visibility) {
2865 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2866 proto::UpdateChannels {
2867 channels: vec![channel.to_proto()],
2868 ..Default::default()
2869 }
2870 } else {
2871 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2872 proto::UpdateChannels {
2873 delete_channels: vec![channel.id.to_proto()],
2874 ..Default::default()
2875 }
2876 };
2877
2878 for connection_id in connection_pool.user_connection_ids(user_id) {
2879 session.peer.send(connection_id, update.clone())?;
2880 }
2881 }
2882
2883 response.send(proto::Ack {})?;
2884 Ok(())
2885}
2886
2887/// Alter the role for a user in the channel.
2888async fn set_channel_member_role(
2889 request: proto::SetChannelMemberRole,
2890 response: Response<proto::SetChannelMemberRole>,
2891 session: Session,
2892) -> Result<()> {
2893 let db = session.db().await;
2894 let channel_id = ChannelId::from_proto(request.channel_id);
2895 let member_id = UserId::from_proto(request.user_id);
2896 let result = db
2897 .set_channel_member_role(
2898 channel_id,
2899 session.user_id(),
2900 member_id,
2901 request.role().into(),
2902 )
2903 .await?;
2904
2905 match result {
2906 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2907 let mut connection_pool = session.connection_pool().await;
2908 notify_membership_updated(
2909 &mut connection_pool,
2910 membership_update,
2911 member_id,
2912 &session.peer,
2913 )
2914 }
2915 db::SetMemberRoleResult::InviteUpdated(channel) => {
2916 let update = proto::UpdateChannels {
2917 channel_invitations: vec![channel.to_proto()],
2918 ..Default::default()
2919 };
2920
2921 for connection_id in session
2922 .connection_pool()
2923 .await
2924 .user_connection_ids(member_id)
2925 {
2926 session.peer.send(connection_id, update.clone())?;
2927 }
2928 }
2929 }
2930
2931 response.send(proto::Ack {})?;
2932 Ok(())
2933}
2934
2935/// Change the name of a channel
2936async fn rename_channel(
2937 request: proto::RenameChannel,
2938 response: Response<proto::RenameChannel>,
2939 session: Session,
2940) -> Result<()> {
2941 let db = session.db().await;
2942 let channel_id = ChannelId::from_proto(request.channel_id);
2943 let channel_model = db
2944 .rename_channel(channel_id, session.user_id(), &request.name)
2945 .await?;
2946 let root_id = channel_model.root_id();
2947 let channel = Channel::from_model(channel_model);
2948
2949 response.send(proto::RenameChannelResponse {
2950 channel: Some(channel.to_proto()),
2951 })?;
2952
2953 let connection_pool = session.connection_pool().await;
2954 let update = proto::UpdateChannels {
2955 channels: vec![channel.to_proto()],
2956 ..Default::default()
2957 };
2958 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2959 if role.can_see_channel(channel.visibility) {
2960 session.peer.send(connection_id, update.clone())?;
2961 }
2962 }
2963
2964 Ok(())
2965}
2966
2967/// Move a channel to a new parent.
2968async fn move_channel(
2969 request: proto::MoveChannel,
2970 response: Response<proto::MoveChannel>,
2971 session: Session,
2972) -> Result<()> {
2973 let channel_id = ChannelId::from_proto(request.channel_id);
2974 let to = ChannelId::from_proto(request.to);
2975
2976 let (root_id, channels) = session
2977 .db()
2978 .await
2979 .move_channel(channel_id, to, session.user_id())
2980 .await?;
2981
2982 let connection_pool = session.connection_pool().await;
2983 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2984 let channels = channels
2985 .iter()
2986 .filter_map(|channel| {
2987 if role.can_see_channel(channel.visibility) {
2988 Some(channel.to_proto())
2989 } else {
2990 None
2991 }
2992 })
2993 .collect::<Vec<_>>();
2994 if channels.is_empty() {
2995 continue;
2996 }
2997
2998 let update = proto::UpdateChannels {
2999 channels,
3000 ..Default::default()
3001 };
3002
3003 session.peer.send(connection_id, update.clone())?;
3004 }
3005
3006 response.send(Ack {})?;
3007 Ok(())
3008}
3009
3010/// Get the list of channel members
3011async fn get_channel_members(
3012 request: proto::GetChannelMembers,
3013 response: Response<proto::GetChannelMembers>,
3014 session: Session,
3015) -> Result<()> {
3016 let db = session.db().await;
3017 let channel_id = ChannelId::from_proto(request.channel_id);
3018 let limit = if request.limit == 0 {
3019 u16::MAX as u64
3020 } else {
3021 request.limit
3022 };
3023 let (members, users) = db
3024 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3025 .await?;
3026 response.send(proto::GetChannelMembersResponse { members, users })?;
3027 Ok(())
3028}
3029
3030/// Accept or decline a channel invitation.
3031async fn respond_to_channel_invite(
3032 request: proto::RespondToChannelInvite,
3033 response: Response<proto::RespondToChannelInvite>,
3034 session: Session,
3035) -> Result<()> {
3036 let db = session.db().await;
3037 let channel_id = ChannelId::from_proto(request.channel_id);
3038 let RespondToChannelInvite {
3039 membership_update,
3040 notifications,
3041 } = db
3042 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3043 .await?;
3044
3045 let mut connection_pool = session.connection_pool().await;
3046 if let Some(membership_update) = membership_update {
3047 notify_membership_updated(
3048 &mut connection_pool,
3049 membership_update,
3050 session.user_id(),
3051 &session.peer,
3052 );
3053 } else {
3054 let update = proto::UpdateChannels {
3055 remove_channel_invitations: vec![channel_id.to_proto()],
3056 ..Default::default()
3057 };
3058
3059 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3060 session.peer.send(connection_id, update.clone())?;
3061 }
3062 };
3063
3064 send_notifications(&connection_pool, &session.peer, notifications);
3065
3066 response.send(proto::Ack {})?;
3067
3068 Ok(())
3069}
3070
3071/// Join the channels' room
3072async fn join_channel(
3073 request: proto::JoinChannel,
3074 response: Response<proto::JoinChannel>,
3075 session: Session,
3076) -> Result<()> {
3077 let channel_id = ChannelId::from_proto(request.channel_id);
3078 join_channel_internal(channel_id, Box::new(response), session).await
3079}
3080
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>` so that `join_channel_internal` can reply to
/// either kind of request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path selects `Response`'s own `send` rather
        // than this trait method of the same name (which would recurse).
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path selects `Response`'s own `send` rather
        // than this trait method of the same name (which would recurse).
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3094
/// Shared implementation for joining a channel's room. The reply goes through
/// any `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, its channel id, and LiveKit connection info (only
///    when a LiveKit client is configured and a token could be issued).
/// 4. Propagate any membership change and broadcast the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// Returns an error if the database does not return a channel for the room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; all other roles get a
        // regular room token. A token error is logged by `trace_err` and
        // results in no connection info (`None`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` are dropped at the end of this block,
        // before the connection pool is re-acquired for `channel_updated`.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3190
3191/// Start editing the channel notes
3192async fn join_channel_buffer(
3193 request: proto::JoinChannelBuffer,
3194 response: Response<proto::JoinChannelBuffer>,
3195 session: Session,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199
3200 let open_response = db
3201 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3202 .await?;
3203
3204 let collaborators = open_response.collaborators.clone();
3205 response.send(open_response)?;
3206
3207 let update = UpdateChannelBufferCollaborators {
3208 channel_id: channel_id.to_proto(),
3209 collaborators: collaborators.clone(),
3210 };
3211 channel_buffer_updated(
3212 session.connection_id,
3213 collaborators
3214 .iter()
3215 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3216 &update,
3217 &session.peer,
3218 );
3219
3220 Ok(())
3221}
3222
3223/// Edit the channel notes
3224async fn update_channel_buffer(
3225 request: proto::UpdateChannelBuffer,
3226 session: Session,
3227) -> Result<()> {
3228 let db = session.db().await;
3229 let channel_id = ChannelId::from_proto(request.channel_id);
3230
3231 let (collaborators, epoch, version) = db
3232 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3233 .await?;
3234
3235 channel_buffer_updated(
3236 session.connection_id,
3237 collaborators.clone(),
3238 &proto::UpdateChannelBuffer {
3239 channel_id: channel_id.to_proto(),
3240 operations: request.operations,
3241 },
3242 &session.peer,
3243 );
3244
3245 let pool = &*session.connection_pool().await;
3246
3247 let non_collaborators =
3248 pool.channel_connection_ids(channel_id)
3249 .filter_map(|(connection_id, _)| {
3250 if collaborators.contains(&connection_id) {
3251 None
3252 } else {
3253 Some(connection_id)
3254 }
3255 });
3256
3257 broadcast(None, non_collaborators, |peer_id| {
3258 session.peer.send(
3259 peer_id,
3260 proto::UpdateChannels {
3261 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3262 channel_id: channel_id.to_proto(),
3263 epoch: epoch as u64,
3264 version: version.clone(),
3265 }],
3266 ..Default::default()
3267 },
3268 )
3269 });
3270
3271 Ok(())
3272}
3273
3274/// Rejoin the channel notes after a connection blip
3275async fn rejoin_channel_buffers(
3276 request: proto::RejoinChannelBuffers,
3277 response: Response<proto::RejoinChannelBuffers>,
3278 session: Session,
3279) -> Result<()> {
3280 let db = session.db().await;
3281 let buffers = db
3282 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3283 .await?;
3284
3285 for rejoined_buffer in &buffers {
3286 let collaborators_to_notify = rejoined_buffer
3287 .buffer
3288 .collaborators
3289 .iter()
3290 .filter_map(|c| Some(c.peer_id?.into()));
3291 channel_buffer_updated(
3292 session.connection_id,
3293 collaborators_to_notify,
3294 &proto::UpdateChannelBufferCollaborators {
3295 channel_id: rejoined_buffer.buffer.channel_id,
3296 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3297 },
3298 &session.peer,
3299 );
3300 }
3301
3302 response.send(proto::RejoinChannelBuffersResponse {
3303 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3304 })?;
3305
3306 Ok(())
3307}
3308
3309/// Stop editing the channel notes
3310async fn leave_channel_buffer(
3311 request: proto::LeaveChannelBuffer,
3312 response: Response<proto::LeaveChannelBuffer>,
3313 session: Session,
3314) -> Result<()> {
3315 let db = session.db().await;
3316 let channel_id = ChannelId::from_proto(request.channel_id);
3317
3318 let left_buffer = db
3319 .leave_channel_buffer(channel_id, session.connection_id)
3320 .await?;
3321
3322 response.send(Ack {})?;
3323
3324 channel_buffer_updated(
3325 session.connection_id,
3326 left_buffer.connections,
3327 &proto::UpdateChannelBufferCollaborators {
3328 channel_id: channel_id.to_proto(),
3329 collaborators: left_buffer.collaborators,
3330 },
3331 &session.peer,
3332 );
3333
3334 Ok(())
3335}
3336
3337fn channel_buffer_updated<T: EnvelopedMessage>(
3338 sender_id: ConnectionId,
3339 collaborators: impl IntoIterator<Item = ConnectionId>,
3340 message: &T,
3341 peer: &Peer,
3342) {
3343 broadcast(Some(sender_id), collaborators, |peer_id| {
3344 peer.send(peer_id, message.clone())
3345 });
3346}
3347
3348fn send_notifications(
3349 connection_pool: &ConnectionPool,
3350 peer: &Peer,
3351 notifications: db::NotificationBatch,
3352) {
3353 for (user_id, notification) in notifications {
3354 for connection_id in connection_pool.user_connection_ids(user_id) {
3355 if let Err(error) = peer.send(
3356 connection_id,
3357 proto::AddNotification {
3358 notification: Some(notification.clone()),
3359 },
3360 ) {
3361 tracing::error!(
3362 "failed to send notification to {:?} {}",
3363 connection_id,
3364 error
3365 );
3366 }
3367 }
3368 }
3369}
3370
/// Send a message to the channel
///
/// The body is trimmed. It is rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty, and a nonce is required. Once the database has stored the
/// message:
/// - `ChannelMessageSent` is broadcast to the returned participant
///   connections, with the sender's connection passed as the broadcast sender;
/// - the sender gets the stored message back in the response;
/// - connections subscribed to the channel that are not participants get the
///   latest message id via `UpdateChannels`;
/// - any notifications returned by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below. If
    // mentions are positions within the body, trimming leading whitespace
    // shifts them; confirm against the mention format.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to the connections that have the chat open.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that are not chat participants only receive the
    // newest message id rather than the message itself.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3465
3466/// Delete a channel message
3467async fn remove_channel_message(
3468 request: proto::RemoveChannelMessage,
3469 response: Response<proto::RemoveChannelMessage>,
3470 session: Session,
3471) -> Result<()> {
3472 let channel_id = ChannelId::from_proto(request.channel_id);
3473 let message_id = MessageId::from_proto(request.message_id);
3474 let (connection_ids, existing_notification_ids) = session
3475 .db()
3476 .await
3477 .remove_channel_message(channel_id, message_id, session.user_id())
3478 .await?;
3479
3480 broadcast(
3481 Some(session.connection_id),
3482 connection_ids,
3483 move |connection| {
3484 session.peer.send(connection, request.clone())?;
3485
3486 for notification_id in &existing_notification_ids {
3487 session.peer.send(
3488 connection,
3489 proto::DeleteNotification {
3490 notification_id: (*notification_id).to_proto(),
3491 },
3492 )?;
3493 }
3494
3495 Ok(())
3496 },
3497 );
3498 response.send(proto::Ack {})?;
3499 Ok(())
3500}
3501
3502async fn update_channel_message(
3503 request: proto::UpdateChannelMessage,
3504 response: Response<proto::UpdateChannelMessage>,
3505 session: Session,
3506) -> Result<()> {
3507 let channel_id = ChannelId::from_proto(request.channel_id);
3508 let message_id = MessageId::from_proto(request.message_id);
3509 let updated_at = OffsetDateTime::now_utc();
3510 let UpdatedChannelMessage {
3511 message_id,
3512 participant_connection_ids,
3513 notifications,
3514 reply_to_message_id,
3515 timestamp,
3516 deleted_mention_notification_ids,
3517 updated_mention_notifications,
3518 } = session
3519 .db()
3520 .await
3521 .update_channel_message(
3522 channel_id,
3523 message_id,
3524 session.user_id(),
3525 request.body.as_str(),
3526 &request.mentions,
3527 updated_at,
3528 )
3529 .await?;
3530
3531 let nonce = request
3532 .nonce
3533 .clone()
3534 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3535
3536 let message = proto::ChannelMessage {
3537 sender_id: session.user_id().to_proto(),
3538 id: message_id.to_proto(),
3539 body: request.body.clone(),
3540 mentions: request.mentions.clone(),
3541 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3542 nonce: Some(nonce),
3543 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3544 edited_at: Some(updated_at.unix_timestamp() as u64),
3545 };
3546
3547 response.send(proto::Ack {})?;
3548
3549 let pool = &*session.connection_pool().await;
3550 broadcast(
3551 Some(session.connection_id),
3552 participant_connection_ids,
3553 |connection| {
3554 session.peer.send(
3555 connection,
3556 proto::ChannelMessageUpdate {
3557 channel_id: channel_id.to_proto(),
3558 message: Some(message.clone()),
3559 },
3560 )?;
3561
3562 for notification_id in &deleted_mention_notification_ids {
3563 session.peer.send(
3564 connection,
3565 proto::DeleteNotification {
3566 notification_id: (*notification_id).to_proto(),
3567 },
3568 )?;
3569 }
3570
3571 for notification in &updated_mention_notifications {
3572 session.peer.send(
3573 connection,
3574 proto::UpdateNotification {
3575 notification: Some(notification.clone()),
3576 },
3577 )?;
3578 }
3579
3580 Ok(())
3581 },
3582 );
3583
3584 send_notifications(pool, &session.peer, notifications);
3585
3586 Ok(())
3587}
3588
3589/// Mark a channel message as read
3590async fn acknowledge_channel_message(
3591 request: proto::AckChannelMessage,
3592 session: Session,
3593) -> Result<()> {
3594 let channel_id = ChannelId::from_proto(request.channel_id);
3595 let message_id = MessageId::from_proto(request.message_id);
3596 let notifications = session
3597 .db()
3598 .await
3599 .observe_channel_message(channel_id, session.user_id(), message_id)
3600 .await?;
3601 send_notifications(
3602 &*session.connection_pool().await,
3603 &session.peer,
3604 notifications,
3605 );
3606 Ok(())
3607}
3608
3609/// Mark a buffer version as synced
3610async fn acknowledge_buffer_version(
3611 request: proto::AckBufferOperation,
3612 session: Session,
3613) -> Result<()> {
3614 let buffer_id = BufferId::from_proto(request.buffer_id);
3615 session
3616 .db()
3617 .await
3618 .observe_buffer_version(
3619 buffer_id,
3620 session.user_id(),
3621 request.epoch as i32,
3622 &request.version,
3623 )
3624 .await?;
3625 Ok(())
3626}
3627
3628async fn count_language_model_tokens(
3629 request: proto::CountLanguageModelTokens,
3630 response: Response<proto::CountLanguageModelTokens>,
3631 session: Session,
3632 config: &Config,
3633) -> Result<()> {
3634 authorize_access_to_legacy_llm_endpoints(&session).await?;
3635
3636 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3637 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3638 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3639 };
3640
3641 session
3642 .app_state
3643 .rate_limiter
3644 .check(&*rate_limit, session.user_id())
3645 .await?;
3646
3647 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3648 Some(proto::LanguageModelProvider::Google) => {
3649 let api_key = config
3650 .google_ai_api_key
3651 .as_ref()
3652 .context("no Google AI API key configured on the server")?;
3653 google_ai::count_tokens(
3654 session.http_client.as_ref(),
3655 google_ai::API_URL,
3656 api_key,
3657 serde_json::from_str(&request.request)?,
3658 )
3659 .await?
3660 }
3661 _ => return Err(anyhow!("unsupported provider"))?,
3662 };
3663
3664 response.send(proto::CountLanguageModelTokensResponse {
3665 token_count: result.total_tokens as u32,
3666 })?;
3667
3668 Ok(())
3669}
3670
3671struct ZedProCountLanguageModelTokensRateLimit;
3672
3673impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3674 fn capacity(&self) -> usize {
3675 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3676 .ok()
3677 .and_then(|v| v.parse().ok())
3678 .unwrap_or(600) // Picked arbitrarily
3679 }
3680
3681 fn refill_duration(&self) -> chrono::Duration {
3682 chrono::Duration::hours(1)
3683 }
3684
3685 fn db_name(&self) -> &'static str {
3686 "zed-pro:count-language-model-tokens"
3687 }
3688}
3689
3690struct FreeCountLanguageModelTokensRateLimit;
3691
3692impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3693 fn capacity(&self) -> usize {
3694 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3695 .ok()
3696 .and_then(|v| v.parse().ok())
3697 .unwrap_or(600 / 10) // Picked arbitrarily
3698 }
3699
3700 fn refill_duration(&self) -> chrono::Duration {
3701 chrono::Duration::hours(1)
3702 }
3703
3704 fn db_name(&self) -> &'static str {
3705 "free:count-language-model-tokens"
3706 }
3707}
3708
3709struct ZedProComputeEmbeddingsRateLimit;
3710
3711impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3712 fn capacity(&self) -> usize {
3713 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3714 .ok()
3715 .and_then(|v| v.parse().ok())
3716 .unwrap_or(5000) // Picked arbitrarily
3717 }
3718
3719 fn refill_duration(&self) -> chrono::Duration {
3720 chrono::Duration::hours(1)
3721 }
3722
3723 fn db_name(&self) -> &'static str {
3724 "zed-pro:compute-embeddings"
3725 }
3726}
3727
3728struct FreeComputeEmbeddingsRateLimit;
3729
3730impl RateLimit for FreeComputeEmbeddingsRateLimit {
3731 fn capacity(&self) -> usize {
3732 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3733 .ok()
3734 .and_then(|v| v.parse().ok())
3735 .unwrap_or(5000 / 10) // Picked arbitrarily
3736 }
3737
3738 fn refill_duration(&self) -> chrono::Duration {
3739 chrono::Duration::hours(1)
3740 }
3741
3742 fn db_name(&self) -> &'static str {
3743 "free:compute-embeddings"
3744 }
3745}
3746
3747async fn compute_embeddings(
3748 request: proto::ComputeEmbeddings,
3749 response: Response<proto::ComputeEmbeddings>,
3750 session: Session,
3751 api_key: Option<Arc<str>>,
3752) -> Result<()> {
3753 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3754 authorize_access_to_legacy_llm_endpoints(&session).await?;
3755
3756 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3757 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3758 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3759 };
3760
3761 session
3762 .app_state
3763 .rate_limiter
3764 .check(&*rate_limit, session.user_id())
3765 .await?;
3766
3767 let embeddings = match request.model.as_str() {
3768 "openai/text-embedding-3-small" => {
3769 open_ai::embed(
3770 session.http_client.as_ref(),
3771 OPEN_AI_API_URL,
3772 &api_key,
3773 OpenAiEmbeddingModel::TextEmbedding3Small,
3774 request.texts.iter().map(|text| text.as_str()),
3775 )
3776 .await?
3777 }
3778 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3779 };
3780
3781 let embeddings = request
3782 .texts
3783 .iter()
3784 .map(|text| {
3785 let mut hasher = sha2::Sha256::new();
3786 hasher.update(text.as_bytes());
3787 let result = hasher.finalize();
3788 result.to_vec()
3789 })
3790 .zip(
3791 embeddings
3792 .data
3793 .into_iter()
3794 .map(|embedding| embedding.embedding),
3795 )
3796 .collect::<HashMap<_, _>>();
3797
3798 let db = session.db().await;
3799 db.save_embeddings(&request.model, &embeddings)
3800 .await
3801 .context("failed to save embeddings")
3802 .trace_err();
3803
3804 response.send(proto::ComputeEmbeddingsResponse {
3805 embeddings: embeddings
3806 .into_iter()
3807 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3808 .collect(),
3809 })?;
3810 Ok(())
3811}
3812
3813async fn get_cached_embeddings(
3814 request: proto::GetCachedEmbeddings,
3815 response: Response<proto::GetCachedEmbeddings>,
3816 session: Session,
3817) -> Result<()> {
3818 authorize_access_to_legacy_llm_endpoints(&session).await?;
3819
3820 let db = session.db().await;
3821 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3822
3823 response.send(proto::GetCachedEmbeddingsResponse {
3824 embeddings: embeddings
3825 .into_iter()
3826 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3827 .collect(),
3828 })?;
3829 Ok(())
3830}
3831
3832/// This is leftover from before the LLM service.
3833///
3834/// The endpoints protected by this check will be moved there eventually.
3835async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3836 if session.is_staff() {
3837 Ok(())
3838 } else {
3839 Err(anyhow!("permission denied"))?
3840 }
3841}
3842
3843/// Get a Supermaven API key for the user
3844async fn get_supermaven_api_key(
3845 _request: proto::GetSupermavenApiKey,
3846 response: Response<proto::GetSupermavenApiKey>,
3847 session: Session,
3848) -> Result<()> {
3849 let user_id: String = session.user_id().to_string();
3850 if !session.is_staff() {
3851 return Err(anyhow!("supermaven not enabled for this account"))?;
3852 }
3853
3854 let email = session
3855 .email()
3856 .ok_or_else(|| anyhow!("user must have an email"))?;
3857
3858 let supermaven_admin_api = session
3859 .supermaven_client
3860 .as_ref()
3861 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3862
3863 let result = supermaven_admin_api
3864 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3865 .await?;
3866
3867 response.send(proto::GetSupermavenApiKeyResponse {
3868 api_key: result.api_key,
3869 })?;
3870
3871 Ok(())
3872}
3873
3874/// Start receiving chat updates for a channel
3875async fn join_channel_chat(
3876 request: proto::JoinChannelChat,
3877 response: Response<proto::JoinChannelChat>,
3878 session: Session,
3879) -> Result<()> {
3880 let channel_id = ChannelId::from_proto(request.channel_id);
3881
3882 let db = session.db().await;
3883 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3884 .await?;
3885 let messages = db
3886 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3887 .await?;
3888 response.send(proto::JoinChannelChatResponse {
3889 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3890 messages,
3891 })?;
3892 Ok(())
3893}
3894
3895/// Stop receiving chat updates for a channel
3896async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3897 let channel_id = ChannelId::from_proto(request.channel_id);
3898 session
3899 .db()
3900 .await
3901 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3902 .await?;
3903 Ok(())
3904}
3905
3906/// Retrieve the chat history for a channel
3907async fn get_channel_messages(
3908 request: proto::GetChannelMessages,
3909 response: Response<proto::GetChannelMessages>,
3910 session: Session,
3911) -> Result<()> {
3912 let channel_id = ChannelId::from_proto(request.channel_id);
3913 let messages = session
3914 .db()
3915 .await
3916 .get_channel_messages(
3917 channel_id,
3918 session.user_id(),
3919 MESSAGE_COUNT_PER_PAGE,
3920 Some(MessageId::from_proto(request.before_message_id)),
3921 )
3922 .await?;
3923 response.send(proto::GetChannelMessagesResponse {
3924 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925 messages,
3926 })?;
3927 Ok(())
3928}
3929
3930/// Retrieve specific chat messages
3931async fn get_channel_messages_by_id(
3932 request: proto::GetChannelMessagesById,
3933 response: Response<proto::GetChannelMessagesById>,
3934 session: Session,
3935) -> Result<()> {
3936 let message_ids = request
3937 .message_ids
3938 .iter()
3939 .map(|id| MessageId::from_proto(*id))
3940 .collect::<Vec<_>>();
3941 let messages = session
3942 .db()
3943 .await
3944 .get_channel_messages_by_id(session.user_id(), &message_ids)
3945 .await?;
3946 response.send(proto::GetChannelMessagesResponse {
3947 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3948 messages,
3949 })?;
3950 Ok(())
3951}
3952
3953/// Retrieve the current users notifications
3954async fn get_notifications(
3955 request: proto::GetNotifications,
3956 response: Response<proto::GetNotifications>,
3957 session: Session,
3958) -> Result<()> {
3959 let notifications = session
3960 .db()
3961 .await
3962 .get_notifications(
3963 session.user_id(),
3964 NOTIFICATION_COUNT_PER_PAGE,
3965 request.before_id.map(db::NotificationId::from_proto),
3966 )
3967 .await?;
3968 response.send(proto::GetNotificationsResponse {
3969 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3970 notifications,
3971 })?;
3972 Ok(())
3973}
3974
3975/// Mark notifications as read
3976async fn mark_notification_as_read(
3977 request: proto::MarkNotificationRead,
3978 response: Response<proto::MarkNotificationRead>,
3979 session: Session,
3980) -> Result<()> {
3981 let database = &session.db().await;
3982 let notifications = database
3983 .mark_notification_as_read_by_id(
3984 session.user_id(),
3985 NotificationId::from_proto(request.notification_id),
3986 )
3987 .await?;
3988 send_notifications(
3989 &*session.connection_pool().await,
3990 &session.peer,
3991 notifications,
3992 );
3993 response.send(proto::Ack {})?;
3994 Ok(())
3995}
3996
3997/// Get the current users information
3998async fn get_private_user_info(
3999 _request: proto::GetPrivateUserInfo,
4000 response: Response<proto::GetPrivateUserInfo>,
4001 session: Session,
4002) -> Result<()> {
4003 let db = session.db().await;
4004
4005 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4006 let user = db
4007 .get_user_by_id(session.user_id())
4008 .await?
4009 .ok_or_else(|| anyhow!("user not found"))?;
4010 let flags = db.get_user_flags(session.user_id()).await?;
4011
4012 response.send(proto::GetPrivateUserInfoResponse {
4013 metrics_id,
4014 staff: user.admin,
4015 flags,
4016 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4017 })?;
4018 Ok(())
4019}
4020
4021/// Accept the terms of service (tos) on behalf of the current user
4022async fn accept_terms_of_service(
4023 _request: proto::AcceptTermsOfService,
4024 response: Response<proto::AcceptTermsOfService>,
4025 session: Session,
4026) -> Result<()> {
4027 let db = session.db().await;
4028
4029 let accepted_tos_at = Utc::now();
4030 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4031 .await?;
4032
4033 response.send(proto::AcceptTermsOfServiceResponse {
4034 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4035 })?;
4036 Ok(())
4037}
4038
/// The minimum account age required before a user can obtain an LLM token.
///
/// Age is measured from the earlier of the Zed account's creation time and the
/// GitHub account's creation time, when that is known. `get_llm_api_token`
/// skips this check for users with an LLM subscription or the
/// `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4041
/// Issue an LLM service token for the current user.
///
/// Checks run in this order, and the first failure is returned:
/// 1. the user must be staff or have the `language-models` feature flag;
/// 2. the user must exist and have accepted the terms of service;
/// 3. unless the user has an LLM subscription or the
///    `bypass-account-age-check` flag, the account must be at least
///    `MIN_ACCOUNT_AGE_FOR_LLM_USE` old.
///
/// The token is built by `LlmTokenClaims::create` from:
/// - the user;
/// - staff status;
/// - billing preferences;
/// - the `llm-closed-beta` and `predict-edits` flags;
/// - subscription status;
/// - the current plan;
/// - the session's system id;
/// - the server config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    // Gate: only staff or users with the `language-models` flag may proceed.
    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        // Use the older of the two creation dates, so a recently created Zed
        // account linked to an older GitHub account is not rejected.
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4098
4099fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4100 let message = match message {
4101 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4102 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4103 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4104 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4105 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4106 code: frame.code.into(),
4107 reason: frame.reason,
4108 })),
4109 // We should never receive a frame while reading the message, according
4110 // to the `tungstenite` maintainers:
4111 //
4112 // > It cannot occur when you read messages from the WebSocket, but it
4113 // > can be used when you want to send the raw frames (e.g. you want to
4114 // > send the frames to the WebSocket without composing the full message first).
4115 // >
4116 // > — https://github.com/snapview/tungstenite-rs/issues/268
4117 TungsteniteMessage::Frame(_) => {
4118 bail!("received an unexpected frame while reading the message")
4119 }
4120 };
4121
4122 Ok(message)
4123}
4124
4125fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4126 match message {
4127 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4128 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4129 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4130 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4131 AxumMessage::Close(frame) => {
4132 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4133 code: frame.code.into(),
4134 reason: frame.reason,
4135 }))
4136 }
4137 }
4138}
4139
4140fn notify_membership_updated(
4141 connection_pool: &mut ConnectionPool,
4142 result: MembershipUpdated,
4143 user_id: UserId,
4144 peer: &Peer,
4145) {
4146 for membership in &result.new_channels.channel_memberships {
4147 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4148 }
4149 for channel_id in &result.removed_channels {
4150 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4151 }
4152
4153 let user_channels_update = proto::UpdateUserChannels {
4154 channel_memberships: result
4155 .new_channels
4156 .channel_memberships
4157 .iter()
4158 .map(|cm| proto::ChannelMembership {
4159 channel_id: cm.channel_id.to_proto(),
4160 role: cm.role.into(),
4161 })
4162 .collect(),
4163 ..Default::default()
4164 };
4165
4166 let mut update = build_channels_update(result.new_channels);
4167 update.delete_channels = result
4168 .removed_channels
4169 .into_iter()
4170 .map(|id| id.to_proto())
4171 .collect();
4172 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4173
4174 for connection_id in connection_pool.user_connection_ids(user_id) {
4175 peer.send(connection_id, user_channels_update.clone())
4176 .trace_err();
4177 peer.send(connection_id, update.clone()).trace_err();
4178 }
4179}
4180
4181fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4182 proto::UpdateUserChannels {
4183 channel_memberships: channels
4184 .channel_memberships
4185 .iter()
4186 .map(|m| proto::ChannelMembership {
4187 channel_id: m.channel_id.to_proto(),
4188 role: m.role.into(),
4189 })
4190 .collect(),
4191 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4192 observed_channel_message_id: channels.observed_channel_messages.clone(),
4193 }
4194}
4195
4196fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4197 let mut update = proto::UpdateChannels::default();
4198
4199 for channel in channels.channels {
4200 update.channels.push(channel.to_proto());
4201 }
4202
4203 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4204 update.latest_channel_message_ids = channels.latest_channel_messages;
4205
4206 for (channel_id, participants) in channels.channel_participants {
4207 update
4208 .channel_participants
4209 .push(proto::ChannelParticipants {
4210 channel_id: channel_id.to_proto(),
4211 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4212 });
4213 }
4214
4215 for channel in channels.invited_channels {
4216 update.channel_invitations.push(channel.to_proto());
4217 }
4218
4219 update
4220}
4221
4222fn build_initial_contacts_update(
4223 contacts: Vec<db::Contact>,
4224 pool: &ConnectionPool,
4225) -> proto::UpdateContacts {
4226 let mut update = proto::UpdateContacts::default();
4227
4228 for contact in contacts {
4229 match contact {
4230 db::Contact::Accepted { user_id, busy } => {
4231 update.contacts.push(contact_for_user(user_id, busy, pool));
4232 }
4233 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4234 db::Contact::Incoming { user_id } => {
4235 update
4236 .incoming_requests
4237 .push(proto::IncomingContactRequest {
4238 requester_id: user_id.to_proto(),
4239 })
4240 }
4241 }
4242 }
4243
4244 update
4245}
4246
4247fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4248 proto::Contact {
4249 user_id: user_id.to_proto(),
4250 online: pool.is_user_online(user_id),
4251 busy,
4252 }
4253}
4254
4255fn room_updated(room: &proto::Room, peer: &Peer) {
4256 broadcast(
4257 None,
4258 room.participants
4259 .iter()
4260 .filter_map(|participant| Some(participant.peer_id?.into())),
4261 |peer_id| {
4262 peer.send(
4263 peer_id,
4264 proto::RoomUpdated {
4265 room: Some(room.clone()),
4266 },
4267 )
4268 },
4269 );
4270}
4271
4272fn channel_updated(
4273 channel: &db::channel::Model,
4274 room: &proto::Room,
4275 peer: &Peer,
4276 pool: &ConnectionPool,
4277) {
4278 let participants = room
4279 .participants
4280 .iter()
4281 .map(|p| p.user_id)
4282 .collect::<Vec<_>>();
4283
4284 broadcast(
4285 None,
4286 pool.channel_connection_ids(channel.root_id())
4287 .filter_map(|(channel_id, role)| {
4288 role.can_see_channel(channel.visibility)
4289 .then_some(channel_id)
4290 }),
4291 |peer_id| {
4292 peer.send(
4293 peer_id,
4294 proto::UpdateChannels {
4295 channel_participants: vec![proto::ChannelParticipants {
4296 channel_id: channel.id.to_proto(),
4297 participant_user_ids: participants.clone(),
4298 }],
4299 ..Default::default()
4300 },
4301 )
4302 },
4303 );
4304}
4305
4306async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4307 let db = session.db().await;
4308
4309 let contacts = db.get_contacts(user_id).await?;
4310 let busy = db.is_user_busy(user_id).await?;
4311
4312 let pool = session.connection_pool().await;
4313 let updated_contact = contact_for_user(user_id, busy, &pool);
4314 for contact in contacts {
4315 if let db::Contact::Accepted {
4316 user_id: contact_user_id,
4317 ..
4318 } = contact
4319 {
4320 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4321 session
4322 .peer
4323 .send(
4324 contact_conn_id,
4325 proto::UpdateContacts {
4326 contacts: vec![updated_contact.clone()],
4327 remove_contacts: Default::default(),
4328 incoming_requests: Default::default(),
4329 remove_incoming_requests: Default::default(),
4330 outgoing_requests: Default::default(),
4331 remove_outgoing_requests: Default::default(),
4332 },
4333 )
4334 .trace_err();
4335 }
4336 }
4337 }
4338 Ok(())
4339}
4340
/// Removes `connection_id` from the room it is in and propagates the results.
///
/// If the database reports that no room was left, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// 1. calls `project_left` for every project the connection left,
/// 2. broadcasts the updated room to its participants and, if the room is
///    attached to a channel, the channel's participant list,
/// 3. sends `CallCanceled` to every connection of each user whose call was
///    canceled,
/// 4. refreshes contact status for this session's user and every canceled
///    callee,
/// 5. if a LiveKit client is configured, removes the user from the LiveKit
///    room and deletes that room when `left_room.deleted` was set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Each of these is assigned exactly once in the `if let` branch below;
    // the `else` branch returns early, so they are always initialized after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): under Rust 2021 temporary-scope rules, the value returned by
    // `session.db().await` lives until the end of this whole `if let … else`
    // statement. If it is a lock guard, it is held across `project_left` and
    // `room_updated` as well.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order matters: `livekit_room` is taken out of `left_room.room` before
        // `room` itself is moved out. As a result, the `room` broadcast below has
        // a defaulted `livekit_room`.
        // NOTE(review): confirm clients don't rely on `livekit_room` in `RoomUpdated`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // The connection-pool handle here is a temporary, released at the end of
    // this statement.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which calls
    // `session.connection_pool()` again (presumably to avoid holding it twice).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup
    // below. Confirm that is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort cleanup: LiveKit errors are logged via `trace_err`, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4414
4415async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4416 let left_channel_buffers = session
4417 .db()
4418 .await
4419 .leave_channel_buffers(session.connection_id)
4420 .await?;
4421
4422 for left_buffer in left_channel_buffers {
4423 channel_buffer_updated(
4424 session.connection_id,
4425 left_buffer.connections,
4426 &proto::UpdateChannelBufferCollaborators {
4427 channel_id: left_buffer.channel_id.to_proto(),
4428 collaborators: left_buffer.collaborators,
4429 },
4430 &session.peer,
4431 );
4432 }
4433
4434 Ok(())
4435}
4436
4437fn project_left(project: &db::LeftProject, session: &Session) {
4438 for connection_id in &project.connection_ids {
4439 if project.should_unshare {
4440 session
4441 .peer
4442 .send(
4443 *connection_id,
4444 proto::UnshareProject {
4445 project_id: project.id.to_proto(),
4446 },
4447 )
4448 .trace_err();
4449 } else {
4450 session
4451 .peer
4452 .send(
4453 *connection_id,
4454 proto::RemoveProjectCollaborator {
4455 project_id: project.id.to_proto(),
4456 peer_id: Some(session.connection_id.into()),
4457 },
4458 )
4459 .trace_err();
4460 }
4461 }
4462}
4463
4464pub trait ResultExt {
4465 type Ok;
4466
4467 fn trace_err(self) -> Option<Self::Ok>;
4468}
4469
4470impl<T, E> ResultExt for Result<T, E>
4471where
4472 E: std::fmt::Debug,
4473{
4474 type Ok = T;
4475
4476 #[track_caller]
4477 fn trace_err(self) -> Option<T> {
4478 match self {
4479 Ok(value) => Some(value),
4480 Err(error) => {
4481 tracing::error!("{:?}", error);
4482 None
4483 }
4484 }
4485 }
4486}