1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Presumably how long a dropped client may take to reconnect before its state
/// is discarded. NOTE(review): the use site is outside this excerpt — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before reclaiming resources owned by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting longer
/// than that means those pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Size limits used by handlers outside this excerpt. The meanings below are
// inferred from the names — NOTE(review): confirm at the use sites.

/// Channel messages per page (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Notifications per page (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum channel message length (presumably; unit not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server. It owns the peer, the pool of live
/// connections, and the table of message handlers.
pub struct Server {
    /// This server's id. Kept behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal. `teardown()` sends `true`. Connection loops refuse to
    /// start while the value is `true` and stop on any change.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes this guard `!Send`. It therefore cannot be
    // held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Each message type may be registered only once; `add_handler` panics on a
    /// duplicate. Request handlers reply through their `Response`. Message
    /// handlers receive only the payload.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if the id ever exceeds u32 — confirm ServerId's range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is stored; each connection subscribes its own receiver.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetStagedText>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The next two handlers need `app_state.config`, so each closure
            // captures its own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
425
    /// Schedules cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task (instrumented with a "start server" span) and
    /// returns `Ok(())` immediately. The task:
    /// 1. sleeps for [`CLEANUP_TIMEOUT`];
    /// 2. asks the database for stale room and channel-buffer ids, passing the
    ///    configured environment and this server's id;
    /// 3. clears stale channel-buffer collaborators and sends the refreshed
    ///    collaborator list to each remaining connection;
    /// 4. for each stale room:
    ///    - clears stale participants and broadcasts the refreshed room (and
    ///      its channel, if any);
    ///    - sends `CallCanceled` to the connections of users whose calls were
    ///      canceled;
    ///    - pushes the affected users' updated contact entry to their accepted
    ///      contacts;
    ///    - deletes the LiveKit room when no participants remain;
    /// 5. deletes stale server records.
    ///
    /// Every database call and send inside the task goes through `trace_err`.
    /// A failure in one step does not stop the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        // Copy the id now; the spawned task does not observe a later `reset`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails (nothing to notify or delete).
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the (synchronous) pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
571
572 pub fn teardown(&self) {
573 self.peer.teardown();
574 self.connection_pool.lock().reset();
575 let _ = self.teardown.send(true);
576 }
577
578 #[cfg(test)]
579 pub fn reset(&self, id: ServerId) {
580 self.teardown();
581 *self.id.lock() = id;
582 self.peer.reset(id.0 as u32);
583 let _ = self.teardown.send(false);
584 }
585
586 #[cfg(test)]
587 pub fn id(&self) -> ServerId {
588 *self.id.lock()
589 }
590
591 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
592 where
593 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
594 Fut: 'static + Send + Future<Output = Result<()>>,
595 M: EnvelopedMessage,
596 {
597 let prev_handler = self.handlers.insert(
598 TypeId::of::<M>(),
599 Box::new(move |envelope, session| {
600 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
601 let received_at = envelope.received_at;
602 tracing::info!("message received");
603 let start_time = Instant::now();
604 let future = (handler)(*envelope, session);
605 async move {
606 let result = future.await;
607 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
608 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
609 let queue_duration_ms = total_duration_ms - processing_duration_ms;
610 let payload_type = M::NAME;
611
612 match result {
613 Err(error) => {
614 tracing::error!(
615 ?error,
616 total_duration_ms,
617 processing_duration_ms,
618 queue_duration_ms,
619 payload_type,
620 "error handling message"
621 )
622 }
623 Ok(()) => tracing::info!(
624 total_duration_ms,
625 processing_duration_ms,
626 queue_duration_ms,
627 "finished handling message"
628 ),
629 }
630 }
631 .boxed()
632 }),
633 );
634 if prev_handler.is_some() {
635 panic!("registered a handler for the same message twice");
636 }
637 self
638 }
639
640 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
641 where
642 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
643 Fut: 'static + Send + Future<Output = Result<()>>,
644 M: EnvelopedMessage,
645 {
646 self.add_handler(move |envelope, session| handler(envelope.payload, session));
647 self
648 }
649
650 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
651 where
652 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
653 Fut: Send + Future<Output = Result<()>>,
654 M: RequestMessage,
655 {
656 let handler = Arc::new(handler);
657 self.add_handler(move |envelope, session| {
658 let receipt = envelope.receipt();
659 let handler = handler.clone();
660 async move {
661 let peer = session.peer.clone();
662 let responded = Arc::new(AtomicBool::default());
663 let response = Response {
664 peer: peer.clone(),
665 responded: responded.clone(),
666 receipt,
667 };
668 match (handler)(envelope.payload, response, session).await {
669 Ok(()) => {
670 if responded.load(std::sync::atomic::Ordering::SeqCst) {
671 Ok(())
672 } else {
673 Err(anyhow!("handler did not send a response"))?
674 }
675 }
676 Err(error) => {
677 let proto_err = match &error {
678 Error::Internal(err) => err.to_proto(),
679 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
680 };
681 peer.respond_with_error(receipt, proto_err)?;
682 Err(error)
683 }
684 }
685 }
686 })
687 }
688
    /// Builds the future that drives a single client connection until it closes.
    ///
    /// The future runs inside a "handle connection" span carrying the address,
    /// connection id, principal, and GeoIP country. It proceeds as follows:
    /// 1. Returns immediately if the server is already tearing down.
    /// 2. Registers the connection with the peer and creates the HTTP and
    ///    (optional) Supermaven clients.
    /// 3. Builds the `Session` and sends the initial client update.
    /// 4. Dispatches incoming messages to the registered handlers until the
    ///    I/O task ends, the client closes the stream, or teardown is signalled.
    ///
    /// Afterwards it calls `connection_lost`. A teardown *change* returns early
    /// and skips `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to start serving a connection while the server is tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection are in flight. A permit is taken before
            // reading each message and released when that message's handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: the branches are polled in the order written, so teardown and
                // I/O completion win over driving handlers or reading new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Entered only while the handler is invoked, so its synchronous prefix
                            // (e.g. the "message received" log) is recorded inside this span.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any still-pending foreground handlers before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
838
839 async fn send_initial_client_update(
840 &self,
841 connection_id: ConnectionId,
842 principal: &Principal,
843 zed_version: ZedVersion,
844 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
845 session: &Session,
846 ) -> Result<()> {
847 self.peer.send(
848 connection_id,
849 proto::Hello {
850 peer_id: Some(connection_id.into()),
851 },
852 )?;
853 tracing::info!("sent hello message");
854 if let Some(send_connection_id) = send_connection_id.take() {
855 let _ = send_connection_id.send(connection_id);
856 }
857
858 match principal {
859 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
860 if !user.connected_once {
861 self.peer.send(connection_id, proto::ShowContacts {})?;
862 self.app_state
863 .db
864 .set_user_connected_once(user.id, true)
865 .await?;
866 }
867
868 update_user_plan(user.id, session).await?;
869
870 let contacts = self.app_state.db.get_contacts(user.id).await?;
871
872 {
873 let mut pool = self.connection_pool.lock();
874 pool.add_connection(connection_id, user.id, user.admin, zed_version);
875 self.peer.send(
876 connection_id,
877 build_initial_contacts_update(contacts, &pool),
878 )?;
879 }
880
881 if should_auto_subscribe_to_channels(zed_version) {
882 subscribe_user_to_channels(user.id, session).await?;
883 }
884
885 if let Some(incoming_call) =
886 self.app_state.db.incoming_call_for_user(user.id).await?
887 {
888 self.peer.send(connection_id, incoming_call)?;
889 }
890
891 update_user_contacts(user.id, session).await?;
892 }
893 }
894
895 Ok(())
896 }
897
898 pub async fn invite_code_redeemed(
899 self: &Arc<Self>,
900 inviter_id: UserId,
901 invitee_id: UserId,
902 ) -> Result<()> {
903 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
904 if let Some(code) = &user.invite_code {
905 let pool = self.connection_pool.lock();
906 let invitee_contact = contact_for_user(invitee_id, false, &pool);
907 for connection_id in pool.user_connection_ids(inviter_id) {
908 self.peer.send(
909 connection_id,
910 proto::UpdateContacts {
911 contacts: vec![invitee_contact.clone()],
912 ..Default::default()
913 },
914 )?;
915 self.peer.send(
916 connection_id,
917 proto::UpdateInviteInfo {
918 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
919 count: user.invite_count as u32,
920 },
921 )?;
922 }
923 }
924 }
925 Ok(())
926 }
927
928 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
929 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
930 if let Some(invite_code) = &user.invite_code {
931 let pool = self.connection_pool.lock();
932 for connection_id in pool.user_connection_ids(user_id) {
933 self.peer.send(
934 connection_id,
935 proto::UpdateInviteInfo {
936 url: format!(
937 "{}{}",
938 self.app_state.config.invite_link_prefix, invite_code
939 ),
940 count: user.invite_count as u32,
941 },
942 )?;
943 }
944 }
945 }
946 Ok(())
947 }
948
949 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
950 let pool = self.connection_pool.lock();
951 for connection_id in pool.user_connection_ids(user_id) {
952 self.peer
953 .send(connection_id, proto::RefreshLlmToken {})
954 .trace_err();
955 }
956 }
957
958 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
959 ServerSnapshot {
960 connection_pool: ConnectionPoolGuard {
961 guard: self.connection_pool.lock(),
962 _not_send: PhantomData,
963 },
964 peer: &self.peer,
965 }
966 }
967}
968
969impl<'a> Deref for ConnectionPoolGuard<'a> {
970 type Target = ConnectionPool;
971
972 fn deref(&self) -> &Self::Target {
973 &self.guard
974 }
975}
976
977impl<'a> DerefMut for ConnectionPoolGuard<'a> {
978 fn deref_mut(&mut self) -> &mut Self::Target {
979 &mut self.guard
980 }
981}
982
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// Runs when the guard goes away. In test builds it first calls
    /// `check_invariants` (reached through the guard's `Deref`/`DerefMut` to the
    /// pool). Because `drop` runs before the `guard` field is dropped, the check
    /// happens while the pool lock is still held.
    fn drop(&mut self) {
        // Compiled only into test builds; release builds do no extra work here.
        #[cfg(test)]
        self.check_invariants();
    }
}
989
990fn broadcast<F>(
991 sender_id: Option<ConnectionId>,
992 receiver_ids: impl IntoIterator<Item = ConnectionId>,
993 mut f: F,
994) where
995 F: FnMut(ConnectionId) -> anyhow::Result<()>,
996{
997 for receiver_id in receiver_ids {
998 if Some(receiver_id) != sender_id {
999 if let Err(error) = f(receiver_id) {
1000 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1001 }
1002 }
1003 }
1004}
1005
1006pub struct ProtocolVersion(u32);
1007
1008impl Header for ProtocolVersion {
1009 fn name() -> &'static HeaderName {
1010 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1011 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1012 }
1013
1014 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1015 where
1016 Self: Sized,
1017 I: Iterator<Item = &'i axum::http::HeaderValue>,
1018 {
1019 let version = values
1020 .next()
1021 .ok_or_else(axum::headers::Error::invalid)?
1022 .to_str()
1023 .map_err(|_| axum::headers::Error::invalid())?
1024 .parse()
1025 .map_err(|_| axum::headers::Error::invalid())?;
1026 Ok(Self(version))
1027 }
1028
1029 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1030 values.extend([self.0.to_string().parse().unwrap()]);
1031 }
1032}
1033
1034pub struct AppVersionHeader(SemanticVersion);
1035impl Header for AppVersionHeader {
1036 fn name() -> &'static HeaderName {
1037 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1039 }
1040
1041 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042 where
1043 Self: Sized,
1044 I: Iterator<Item = &'i axum::http::HeaderValue>,
1045 {
1046 let version = values
1047 .next()
1048 .ok_or_else(axum::headers::Error::invalid)?
1049 .to_str()
1050 .map_err(|_| axum::headers::Error::invalid())?
1051 .parse()
1052 .map_err(|_| axum::headers::Error::invalid())?;
1053 Ok(Self(version))
1054 }
1055
1056 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057 values.extend([self.0.to_string().parse().unwrap()]);
1058 }
1059}
1060
1061pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1062 Router::new()
1063 .route("/rpc", get(handle_websocket_request))
1064 .layer(
1065 ServiceBuilder::new()
1066 .layer(Extension(server.app_state.clone()))
1067 .layer(middleware::from_fn(auth::validate_header)),
1068 )
1069 .route("/metrics", get(handle_metrics))
1070 .layer(Extension(server))
1071}
1072
1073#[allow(clippy::too_many_arguments)]
1074pub async fn handle_websocket_request(
1075 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1076 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1077 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1078 Extension(server): Extension<Arc<Server>>,
1079 Extension(principal): Extension<Principal>,
1080 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1081 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1082 ws: WebSocketUpgrade,
1083) -> axum::response::Response {
1084 if protocol_version != rpc::PROTOCOL_VERSION {
1085 return (
1086 StatusCode::UPGRADE_REQUIRED,
1087 "client must be upgraded".to_string(),
1088 )
1089 .into_response();
1090 }
1091
1092 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1093 return (
1094 StatusCode::UPGRADE_REQUIRED,
1095 "no version header found".to_string(),
1096 )
1097 .into_response();
1098 };
1099
1100 if !version.can_collaborate() {
1101 return (
1102 StatusCode::UPGRADE_REQUIRED,
1103 "client must be upgraded".to_string(),
1104 )
1105 .into_response();
1106 }
1107
1108 let socket_address = socket_address.to_string();
1109 ws.on_upgrade(move |socket| {
1110 let socket = socket
1111 .map_ok(to_tungstenite_message)
1112 .err_into()
1113 .with(|message| async move { to_axum_message(message) });
1114 let connection = Connection::new(Box::pin(socket));
1115 async move {
1116 server
1117 .handle_connection(
1118 connection,
1119 socket_address,
1120 principal,
1121 version,
1122 country_code_header.map(|header| header.to_string()),
1123 system_id_header.map(|header| header.to_string()),
1124 None,
1125 Executor::Production,
1126 )
1127 .await;
1128 }
1129 })
1130}
1131
1132pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1133 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1134 let connections_metric = CONNECTIONS_METRIC
1135 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1136
1137 let connections = server
1138 .connection_pool
1139 .lock()
1140 .connections()
1141 .filter(|connection| !connection.admin)
1142 .count();
1143 connections_metric.set(connections as _);
1144
1145 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1146 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1147 register_int_gauge!(
1148 "shared_projects",
1149 "number of open projects with one or more guests"
1150 )
1151 .unwrap()
1152 });
1153
1154 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1155 shared_projects_metric.set(shared_projects as _);
1156
1157 let encoder = prometheus::TextEncoder::new();
1158 let metric_families = prometheus::gather();
1159 let encoded_metrics = encoder
1160 .encode_to_string(&metric_families)
1161 .map_err(|err| anyhow!("{}", err))?;
1162 Ok(encoded_metrics)
1163}
1164
/// Handles the loss of a client connection.
///
/// Right away, the peer is disconnected and the connection is removed from the
/// pool; a pool error aborts the handler. The database is then told about the
/// lost connection, and errors from that call are only logged. After that,
/// whichever of these happens first wins:
///
/// - `RECONNECT_TIMEOUT` elapses. The session leaves its room and channel
///   buffers. If the user has no connection left online, it calls
///   `decline_call(None, ..)` and broadcasts any resulting room. It then
///   refreshes the user's contacts.
/// - `teardown` changes. Nothing further is done here; presumably the server
///   is shutting down (TODO confirm).
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch takes precedence when both are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // This connection was already removed from the pool above, so this
            // asks whether any *other* connection of the user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1211
1212/// Acknowledges a ping from a client, used to keep the connection alive.
1213async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1214 response.send(proto::Ack {})?;
1215 Ok(())
1216}
1217
1218/// Creates a new room for calling (outside of channels)
1219async fn create_room(
1220 _request: proto::CreateRoom,
1221 response: Response<proto::CreateRoom>,
1222 session: Session,
1223) -> Result<()> {
1224 let livekit_room = nanoid::nanoid!(30);
1225
1226 let live_kit_connection_info = util::maybe!(async {
1227 let live_kit = session.app_state.livekit_client.as_ref();
1228 let live_kit = live_kit?;
1229 let user_id = session.user_id().to_string();
1230
1231 let token = live_kit
1232 .room_token(&livekit_room, &user_id.to_string())
1233 .trace_err()?;
1234
1235 Some(proto::LiveKitConnectionInfo {
1236 server_url: live_kit.url().into(),
1237 token,
1238 can_publish: true,
1239 })
1240 })
1241 .await;
1242
1243 let room = session
1244 .db()
1245 .await
1246 .create_room(session.user_id(), session.connection_id, &livekit_room)
1247 .await?;
1248
1249 response.send(proto::CreateRoomResponse {
1250 room: Some(room.clone()),
1251 live_kit_connection_info,
1252 })?;
1253
1254 update_user_contacts(session.user_id(), &session).await?;
1255 Ok(())
1256}
1257
1258/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1259async fn join_room(
1260 request: proto::JoinRoom,
1261 response: Response<proto::JoinRoom>,
1262 session: Session,
1263) -> Result<()> {
1264 let room_id = RoomId::from_proto(request.id);
1265
1266 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1267
1268 if let Some(channel_id) = channel_id {
1269 return join_channel_internal(channel_id, Box::new(response), session).await;
1270 }
1271
1272 let joined_room = {
1273 let room = session
1274 .db()
1275 .await
1276 .join_room(room_id, session.user_id(), session.connection_id)
1277 .await?;
1278 room_updated(&room.room, &session.peer);
1279 room.into_inner()
1280 };
1281
1282 for connection_id in session
1283 .connection_pool()
1284 .await
1285 .user_connection_ids(session.user_id())
1286 {
1287 session
1288 .peer
1289 .send(
1290 connection_id,
1291 proto::CallCanceled {
1292 room_id: room_id.to_proto(),
1293 },
1294 )
1295 .trace_err();
1296 }
1297
1298 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1299 {
1300 live_kit
1301 .room_token(
1302 &joined_room.room.livekit_room,
1303 &session.user_id().to_string(),
1304 )
1305 .trace_err()
1306 .map(|token| proto::LiveKitConnectionInfo {
1307 server_url: live_kit.url().into(),
1308 token,
1309 can_publish: true,
1310 })
1311 } else {
1312 None
1313 };
1314
1315 response.send(proto::JoinRoomResponse {
1316 room: Some(joined_room.room),
1317 channel_id: None,
1318 live_kit_connection_info,
1319 })?;
1320
1321 update_user_contacts(session.user_id(), &session).await?;
1322 Ok(())
1323}
1324
/// Rejoins a room after a connection error, using the new connection.
///
/// The database `rejoin_room` call returns the room plus two project lists.
/// Judging by the names, these are projects the caller re-shares (presumably
/// as host) and projects it rejoins (presumably as guest).
///
/// - The caller receives a `RejoinRoomResponse` summarizing both lists, and the
///   room update is broadcast.
/// - For each re-shared project, every collaborator is sent an
///   `UpdateProjectCollaborator` mapping the project's old connection id to the
///   session's new one. The collaborators other than the caller are also
///   forwarded an `UpdateProject` with the project's worktrees.
/// - Rejoined projects are re-synced via `notify_rejoined_projects`.
///
/// Once the database guard is released, `channel_updated` is sent if the room
/// belongs to a channel. Finally the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scope the `rejoined_room` guard so it is dropped before the connection
    // pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators that this project's peer moved from the old
            // connection id to the new one; failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone but the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1416
1417fn notify_rejoined_projects(
1418 rejoined_projects: &mut Vec<RejoinedProject>,
1419 session: &Session,
1420) -> Result<()> {
1421 for project in rejoined_projects.iter() {
1422 for collaborator in &project.collaborators {
1423 session
1424 .peer
1425 .send(
1426 collaborator.connection_id,
1427 proto::UpdateProjectCollaborator {
1428 project_id: project.id.to_proto(),
1429 old_peer_id: Some(project.old_connection_id.into()),
1430 new_peer_id: Some(session.connection_id.into()),
1431 },
1432 )
1433 .trace_err();
1434 }
1435 }
1436
1437 for project in rejoined_projects {
1438 for worktree in mem::take(&mut project.worktrees) {
1439 // Stream this worktree's entries.
1440 let message = proto::UpdateWorktree {
1441 project_id: project.id.to_proto(),
1442 worktree_id: worktree.id,
1443 abs_path: worktree.abs_path.clone(),
1444 root_name: worktree.root_name,
1445 updated_entries: worktree.updated_entries,
1446 removed_entries: worktree.removed_entries,
1447 scan_id: worktree.scan_id,
1448 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1449 updated_repositories: worktree.updated_repositories,
1450 removed_repositories: worktree.removed_repositories,
1451 };
1452 for update in proto::split_worktree_update(message) {
1453 session.peer.send(session.connection_id, update.clone())?;
1454 }
1455
1456 // Stream this worktree's diagnostics.
1457 for summary in worktree.diagnostic_summaries {
1458 session.peer.send(
1459 session.connection_id,
1460 proto::UpdateDiagnosticSummary {
1461 project_id: project.id.to_proto(),
1462 worktree_id: worktree.id,
1463 summary: Some(summary),
1464 },
1465 )?;
1466 }
1467
1468 for settings_file in worktree.settings_files {
1469 session.peer.send(
1470 session.connection_id,
1471 proto::UpdateWorktreeSettings {
1472 project_id: project.id.to_proto(),
1473 worktree_id: worktree.id,
1474 path: settings_file.path,
1475 content: Some(settings_file.content),
1476 kind: Some(settings_file.kind.to_proto().into()),
1477 },
1478 )?;
1479 }
1480 }
1481
1482 for language_server in &project.language_servers {
1483 session.peer.send(
1484 session.connection_id,
1485 proto::UpdateLanguageServer {
1486 project_id: project.id.to_proto(),
1487 language_server_id: language_server.id,
1488 variant: Some(
1489 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1490 proto::LspDiskBasedDiagnosticsUpdated {},
1491 ),
1492 ),
1493 },
1494 )?;
1495 }
1496 }
1497 Ok(())
1498}
1499
1500/// leave room disconnects from the room.
1501async fn leave_room(
1502 _: proto::LeaveRoom,
1503 response: Response<proto::LeaveRoom>,
1504 session: Session,
1505) -> Result<()> {
1506 leave_room_for_session(&session, session.connection_id).await?;
1507 response.send(proto::Ack {})?;
1508 Ok(())
1509}
1510
1511/// Updates the permissions of someone else in the room.
1512async fn set_room_participant_role(
1513 request: proto::SetRoomParticipantRole,
1514 response: Response<proto::SetRoomParticipantRole>,
1515 session: Session,
1516) -> Result<()> {
1517 let user_id = UserId::from_proto(request.user_id);
1518 let role = ChannelRole::from(request.role());
1519
1520 let (livekit_room, can_publish) = {
1521 let room = session
1522 .db()
1523 .await
1524 .set_room_participant_role(
1525 session.user_id(),
1526 RoomId::from_proto(request.room_id),
1527 user_id,
1528 role,
1529 )
1530 .await?;
1531
1532 let livekit_room = room.livekit_room.clone();
1533 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1534 room_updated(&room, &session.peer);
1535 (livekit_room, can_publish)
1536 };
1537
1538 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1539 live_kit
1540 .update_participant(
1541 livekit_room.clone(),
1542 request.user_id.to_string(),
1543 livekit_server::proto::ParticipantPermission {
1544 can_subscribe: true,
1545 can_publish,
1546 can_publish_data: can_publish,
1547 hidden: false,
1548 recorder: false,
1549 },
1550 )
1551 .await
1552 .trace_err();
1553 }
1554
1555 response.send(proto::Ack {})?;
1556 Ok(())
1557}
1558
/// Calls someone else into the current room.
///
/// The called user must be a contact of the caller. The call is recorded in the
/// database (optionally with an initial project), the room update is broadcast,
/// and the called user's contacts are refreshed. The incoming call is then sent
/// as a request to every connection of the called user.
///
/// The first connection that answers successfully causes this request to be
/// acknowledged. If none does (including when the called user has no
/// connections), the call is marked failed in the database, the room is
/// broadcast again, the called user's contacts are refreshed once more, and a
/// "failed to ring user" error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database guard is dropped before contacts are updated and
    // the called user's connections are rung.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently. The pool guard is
    // a temporary of this statement, so it is released before awaiting replies.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful reply wins; failed replies are logged and skipped.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the call: record the failure and report it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1627
1628/// Cancel an outgoing call.
1629async fn cancel_call(
1630 request: proto::CancelCall,
1631 response: Response<proto::CancelCall>,
1632 session: Session,
1633) -> Result<()> {
1634 let called_user_id = UserId::from_proto(request.called_user_id);
1635 let room_id = RoomId::from_proto(request.room_id);
1636 {
1637 let room = session
1638 .db()
1639 .await
1640 .cancel_call(room_id, session.connection_id, called_user_id)
1641 .await?;
1642 room_updated(&room, &session.peer);
1643 }
1644
1645 for connection_id in session
1646 .connection_pool()
1647 .await
1648 .user_connection_ids(called_user_id)
1649 {
1650 session
1651 .peer
1652 .send(
1653 connection_id,
1654 proto::CallCanceled {
1655 room_id: room_id.to_proto(),
1656 },
1657 )
1658 .trace_err();
1659 }
1660 response.send(proto::Ack {})?;
1661
1662 update_user_contacts(called_user_id, &session).await?;
1663 Ok(())
1664}
1665
1666/// Decline an incoming call.
1667async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1668 let room_id = RoomId::from_proto(message.room_id);
1669 {
1670 let room = session
1671 .db()
1672 .await
1673 .decline_call(Some(room_id), session.user_id())
1674 .await?
1675 .ok_or_else(|| anyhow!("failed to decline call"))?;
1676 room_updated(&room, &session.peer);
1677 }
1678
1679 for connection_id in session
1680 .connection_pool()
1681 .await
1682 .user_connection_ids(session.user_id())
1683 {
1684 session
1685 .peer
1686 .send(
1687 connection_id,
1688 proto::CallCanceled {
1689 room_id: room_id.to_proto(),
1690 },
1691 )
1692 .trace_err();
1693 }
1694 update_user_contacts(session.user_id(), &session).await?;
1695 Ok(())
1696}
1697
1698/// Updates other participants in the room with your current location.
1699async fn update_participant_location(
1700 request: proto::UpdateParticipantLocation,
1701 response: Response<proto::UpdateParticipantLocation>,
1702 session: Session,
1703) -> Result<()> {
1704 let room_id = RoomId::from_proto(request.room_id);
1705 let location = request
1706 .location
1707 .ok_or_else(|| anyhow!("invalid location"))?;
1708
1709 let db = session.db().await;
1710 let room = db
1711 .update_room_participant_location(room_id, session.connection_id, location)
1712 .await?;
1713
1714 room_updated(&room, &session.peer);
1715 response.send(proto::Ack {})?;
1716 Ok(())
1717}
1718
1719/// Share a project into the room.
1720async fn share_project(
1721 request: proto::ShareProject,
1722 response: Response<proto::ShareProject>,
1723 session: Session,
1724) -> Result<()> {
1725 let (project_id, room) = &*session
1726 .db()
1727 .await
1728 .share_project(
1729 RoomId::from_proto(request.room_id),
1730 session.connection_id,
1731 &request.worktrees,
1732 request.is_ssh_project,
1733 )
1734 .await?;
1735 response.send(proto::ShareProjectResponse {
1736 project_id: project_id.to_proto(),
1737 })?;
1738 room_updated(room, &session.peer);
1739
1740 Ok(())
1741}
1742
1743/// Unshare a project from the room.
1744async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1745 let project_id = ProjectId::from_proto(message.project_id);
1746 unshare_project_internal(project_id, session.connection_id, &session).await
1747}
1748
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// The database `unshare_project` call returns three things: whether the project
/// should be deleted, the affected room (if any) and the guest connections.
/// Guests other than `connection_id` receive `UnshareProject`; `broadcast` logs
/// any send failures. The room is then broadcast. After the database guard has
/// been released, the project is deleted if the database asked for it.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before `delete_project` reacquires the
    // database below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1786
/// Joins someone else's shared project.
///
/// The database registers the caller on the project and returns the project and
/// the replica id assigned to the caller. The database handle is dropped
/// explicitly before the rest of the join is done by `join_project_internal`.
/// The guard holding `(project, replica_id)` is a lifetime-extended temporary,
/// so it stays alive until this function returns.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project to the guest.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1805
/// Response sink used by `join_project_internal`. It lets that function send a
/// `JoinProjectResponse` without depending on a concrete response type; the only
/// implementation in view is for `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends the join result, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    /// Delegates to `Response`'s inherent `send`. The fully qualified path makes
    /// it explicit that the inherent method, not this trait method, is called.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
1814
1815fn join_project_internal(
1816 response: impl JoinProjectInternalResponse,
1817 session: Session,
1818 project: &mut Project,
1819 replica_id: &ReplicaId,
1820) -> Result<()> {
1821 let collaborators = project
1822 .collaborators
1823 .iter()
1824 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1825 .map(|collaborator| collaborator.to_proto())
1826 .collect::<Vec<_>>();
1827 let project_id = project.id;
1828 let guest_user_id = session.user_id();
1829
1830 let worktrees = project
1831 .worktrees
1832 .iter()
1833 .map(|(id, worktree)| proto::WorktreeMetadata {
1834 id: *id,
1835 root_name: worktree.root_name.clone(),
1836 visible: worktree.visible,
1837 abs_path: worktree.abs_path.clone(),
1838 })
1839 .collect::<Vec<_>>();
1840
1841 let add_project_collaborator = proto::AddProjectCollaborator {
1842 project_id: project_id.to_proto(),
1843 collaborator: Some(proto::Collaborator {
1844 peer_id: Some(session.connection_id.into()),
1845 replica_id: replica_id.0 as u32,
1846 user_id: guest_user_id.to_proto(),
1847 is_host: false,
1848 }),
1849 };
1850
1851 for collaborator in &collaborators {
1852 session
1853 .peer
1854 .send(
1855 collaborator.peer_id.unwrap().into(),
1856 add_project_collaborator.clone(),
1857 )
1858 .trace_err();
1859 }
1860
1861 // First, we send the metadata associated with each worktree.
1862 response.send(proto::JoinProjectResponse {
1863 project_id: project.id.0 as u64,
1864 worktrees: worktrees.clone(),
1865 replica_id: replica_id.0 as u32,
1866 collaborators: collaborators.clone(),
1867 language_servers: project.language_servers.clone(),
1868 role: project.role.into(),
1869 })?;
1870
1871 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1872 // Stream this worktree's entries.
1873 let message = proto::UpdateWorktree {
1874 project_id: project_id.to_proto(),
1875 worktree_id,
1876 abs_path: worktree.abs_path.clone(),
1877 root_name: worktree.root_name,
1878 updated_entries: worktree.entries,
1879 removed_entries: Default::default(),
1880 scan_id: worktree.scan_id,
1881 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1882 updated_repositories: worktree.repository_entries.into_values().collect(),
1883 removed_repositories: Default::default(),
1884 };
1885 for update in proto::split_worktree_update(message) {
1886 session.peer.send(session.connection_id, update.clone())?;
1887 }
1888
1889 // Stream this worktree's diagnostics.
1890 for summary in worktree.diagnostic_summaries {
1891 session.peer.send(
1892 session.connection_id,
1893 proto::UpdateDiagnosticSummary {
1894 project_id: project_id.to_proto(),
1895 worktree_id: worktree.id,
1896 summary: Some(summary),
1897 },
1898 )?;
1899 }
1900
1901 for settings_file in worktree.settings_files {
1902 session.peer.send(
1903 session.connection_id,
1904 proto::UpdateWorktreeSettings {
1905 project_id: project_id.to_proto(),
1906 worktree_id: worktree.id,
1907 path: settings_file.path,
1908 content: Some(settings_file.content),
1909 kind: Some(settings_file.kind.to_proto() as i32),
1910 },
1911 )?;
1912 }
1913 }
1914
1915 for language_server in &project.language_servers {
1916 session.peer.send(
1917 session.connection_id,
1918 proto::UpdateLanguageServer {
1919 project_id: project_id.to_proto(),
1920 language_server_id: language_server.id,
1921 variant: Some(
1922 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1923 proto::LspDiskBasedDiagnosticsUpdated {},
1924 ),
1925 ),
1926 },
1927 )?;
1928 }
1929
1930 Ok(())
1931}
1932
1933/// Leave someone elses shared project.
1934async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1935 let sender_id = session.connection_id;
1936 let project_id = ProjectId::from_proto(request.project_id);
1937 let db = session.db().await;
1938
1939 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1940 tracing::info!(
1941 %project_id,
1942 "leave project"
1943 );
1944
1945 project_left(project, &session);
1946 if let Some(room) = room {
1947 room_updated(room, &session.peer);
1948 }
1949
1950 Ok(())
1951}
1952
1953/// Updates other participants with changes to the project
1954async fn update_project(
1955 request: proto::UpdateProject,
1956 response: Response<proto::UpdateProject>,
1957 session: Session,
1958) -> Result<()> {
1959 let project_id = ProjectId::from_proto(request.project_id);
1960 let (room, guest_connection_ids) = &*session
1961 .db()
1962 .await
1963 .update_project(project_id, session.connection_id, &request.worktrees)
1964 .await?;
1965 broadcast(
1966 Some(session.connection_id),
1967 guest_connection_ids.iter().copied(),
1968 |connection_id| {
1969 session
1970 .peer
1971 .forward_send(session.connection_id, connection_id, request.clone())
1972 },
1973 );
1974 if let Some(room) = room {
1975 room_updated(room, &session.peer);
1976 }
1977 response.send(proto::Ack {})?;
1978
1979 Ok(())
1980}
1981
1982/// Updates other participants with changes to the worktree
1983async fn update_worktree(
1984 request: proto::UpdateWorktree,
1985 response: Response<proto::UpdateWorktree>,
1986 session: Session,
1987) -> Result<()> {
1988 let guest_connection_ids = session
1989 .db()
1990 .await
1991 .update_worktree(&request, session.connection_id)
1992 .await?;
1993
1994 broadcast(
1995 Some(session.connection_id),
1996 guest_connection_ids.iter().copied(),
1997 |connection_id| {
1998 session
1999 .peer
2000 .forward_send(session.connection_id, connection_id, request.clone())
2001 },
2002 );
2003 response.send(proto::Ack {})?;
2004 Ok(())
2005}
2006
2007/// Updates other participants with changes to the diagnostics
2008async fn update_diagnostic_summary(
2009 message: proto::UpdateDiagnosticSummary,
2010 session: Session,
2011) -> Result<()> {
2012 let guest_connection_ids = session
2013 .db()
2014 .await
2015 .update_diagnostic_summary(&message, session.connection_id)
2016 .await?;
2017
2018 broadcast(
2019 Some(session.connection_id),
2020 guest_connection_ids.iter().copied(),
2021 |connection_id| {
2022 session
2023 .peer
2024 .forward_send(session.connection_id, connection_id, message.clone())
2025 },
2026 );
2027
2028 Ok(())
2029}
2030
2031/// Updates other participants with changes to the worktree settings
2032async fn update_worktree_settings(
2033 message: proto::UpdateWorktreeSettings,
2034 session: Session,
2035) -> Result<()> {
2036 let guest_connection_ids = session
2037 .db()
2038 .await
2039 .update_worktree_settings(&message, session.connection_id)
2040 .await?;
2041
2042 broadcast(
2043 Some(session.connection_id),
2044 guest_connection_ids.iter().copied(),
2045 |connection_id| {
2046 session
2047 .peer
2048 .forward_send(session.connection_id, connection_id, message.clone())
2049 },
2050 );
2051
2052 Ok(())
2053}
2054
2055/// Notify other participants that a language server has started.
2056async fn start_language_server(
2057 request: proto::StartLanguageServer,
2058 session: Session,
2059) -> Result<()> {
2060 let guest_connection_ids = session
2061 .db()
2062 .await
2063 .start_language_server(&request, session.connection_id)
2064 .await?;
2065
2066 broadcast(
2067 Some(session.connection_id),
2068 guest_connection_ids.iter().copied(),
2069 |connection_id| {
2070 session
2071 .peer
2072 .forward_send(session.connection_id, connection_id, request.clone())
2073 },
2074 );
2075 Ok(())
2076}
2077
2078/// Notify other participants that a language server has changed.
2079async fn update_language_server(
2080 request: proto::UpdateLanguageServer,
2081 session: Session,
2082) -> Result<()> {
2083 let project_id = ProjectId::from_proto(request.project_id);
2084 let project_connection_ids = session
2085 .db()
2086 .await
2087 .project_connection_ids(project_id, session.connection_id, true)
2088 .await?;
2089 broadcast(
2090 Some(session.connection_id),
2091 project_connection_ids.iter().copied(),
2092 |connection_id| {
2093 session
2094 .peer
2095 .forward_send(session.connection_id, connection_id, request.clone())
2096 },
2097 );
2098 Ok(())
2099}
2100
2101/// forward a project request to the host. These requests should be read only
2102/// as guests are allowed to send them.
2103async fn forward_read_only_project_request<T>(
2104 request: T,
2105 response: Response<T>,
2106 session: Session,
2107) -> Result<()>
2108where
2109 T: EntityMessage + RequestMessage,
2110{
2111 let project_id = ProjectId::from_proto(request.remote_entity_id());
2112 let host_connection_id = session
2113 .db()
2114 .await
2115 .host_for_read_only_project_request(project_id, session.connection_id)
2116 .await?;
2117 let payload = session
2118 .peer
2119 .forward_request(session.connection_id, host_connection_id, request)
2120 .await?;
2121 response.send(payload)?;
2122 Ok(())
2123}
2124
2125async fn forward_find_search_candidates_request(
2126 request: proto::FindSearchCandidates,
2127 response: Response<proto::FindSearchCandidates>,
2128 session: Session,
2129) -> Result<()> {
2130 let project_id = ProjectId::from_proto(request.remote_entity_id());
2131 let host_connection_id = session
2132 .db()
2133 .await
2134 .host_for_read_only_project_request(project_id, session.connection_id)
2135 .await?;
2136 let payload = session
2137 .peer
2138 .forward_request(session.connection_id, host_connection_id, request)
2139 .await?;
2140 response.send(payload)?;
2141 Ok(())
2142}
2143
2144/// forward a project request to the host. These requests are disallowed
2145/// for guests.
2146async fn forward_mutating_project_request<T>(
2147 request: T,
2148 response: Response<T>,
2149 session: Session,
2150) -> Result<()>
2151where
2152 T: EntityMessage + RequestMessage,
2153{
2154 let project_id = ProjectId::from_proto(request.remote_entity_id());
2155
2156 let host_connection_id = session
2157 .db()
2158 .await
2159 .host_for_mutating_project_request(project_id, session.connection_id)
2160 .await?;
2161 let payload = session
2162 .peer
2163 .forward_request(session.connection_id, host_connection_id, request)
2164 .await?;
2165 response.send(payload)?;
2166 Ok(())
2167}
2168
2169/// Notify other participants that a new buffer has been created
2170async fn create_buffer_for_peer(
2171 request: proto::CreateBufferForPeer,
2172 session: Session,
2173) -> Result<()> {
2174 session
2175 .db()
2176 .await
2177 .check_user_is_project_host(
2178 ProjectId::from_proto(request.project_id),
2179 session.connection_id,
2180 )
2181 .await?;
2182 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2183 session
2184 .peer
2185 .forward_send(session.connection_id, peer_id.into(), request)?;
2186 Ok(())
2187}
2188
2189/// Notify other participants that a buffer has been updated. This is
2190/// allowed for guests as long as the update is limited to selections.
2191async fn update_buffer(
2192 request: proto::UpdateBuffer,
2193 response: Response<proto::UpdateBuffer>,
2194 session: Session,
2195) -> Result<()> {
2196 let project_id = ProjectId::from_proto(request.project_id);
2197 let mut capability = Capability::ReadOnly;
2198
2199 for op in request.operations.iter() {
2200 match op.variant {
2201 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2202 Some(_) => capability = Capability::ReadWrite,
2203 }
2204 }
2205
2206 let host = {
2207 let guard = session
2208 .db()
2209 .await
2210 .connections_for_buffer_update(project_id, session.connection_id, capability)
2211 .await?;
2212
2213 let (host, guests) = &*guard;
2214
2215 broadcast(
2216 Some(session.connection_id),
2217 guests.clone(),
2218 |connection_id| {
2219 session
2220 .peer
2221 .forward_send(session.connection_id, connection_id, request.clone())
2222 },
2223 );
2224
2225 *host
2226 };
2227
2228 if host != session.connection_id {
2229 session
2230 .peer
2231 .forward_request(session.connection_id, host, request.clone())
2232 .await?;
2233 }
2234
2235 response.send(proto::Ack {})?;
2236 Ok(())
2237}
2238
2239async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2240 let project_id = ProjectId::from_proto(message.project_id);
2241
2242 let operation = message.operation.as_ref().context("invalid operation")?;
2243 let capability = match operation.variant.as_ref() {
2244 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2245 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2246 match buffer_op.variant {
2247 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2248 Capability::ReadOnly
2249 }
2250 _ => Capability::ReadWrite,
2251 }
2252 } else {
2253 Capability::ReadWrite
2254 }
2255 }
2256 Some(_) => Capability::ReadWrite,
2257 None => Capability::ReadOnly,
2258 };
2259
2260 let guard = session
2261 .db()
2262 .await
2263 .connections_for_buffer_update(project_id, session.connection_id, capability)
2264 .await?;
2265
2266 let (host, guests) = &*guard;
2267
2268 broadcast(
2269 Some(session.connection_id),
2270 guests.iter().chain([host]).copied(),
2271 |connection_id| {
2272 session
2273 .peer
2274 .forward_send(session.connection_id, connection_id, message.clone())
2275 },
2276 );
2277
2278 Ok(())
2279}
2280
2281/// Notify other participants that a project has been updated.
2282async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2283 request: T,
2284 session: Session,
2285) -> Result<()> {
2286 let project_id = ProjectId::from_proto(request.remote_entity_id());
2287 let project_connection_ids = session
2288 .db()
2289 .await
2290 .project_connection_ids(project_id, session.connection_id, false)
2291 .await?;
2292
2293 broadcast(
2294 Some(session.connection_id),
2295 project_connection_ids.iter().copied(),
2296 |connection_id| {
2297 session
2298 .peer
2299 .forward_send(session.connection_id, connection_id, request.clone())
2300 },
2301 );
2302 Ok(())
2303}
2304
2305/// Start following another user in a call.
2306async fn follow(
2307 request: proto::Follow,
2308 response: Response<proto::Follow>,
2309 session: Session,
2310) -> Result<()> {
2311 let room_id = RoomId::from_proto(request.room_id);
2312 let project_id = request.project_id.map(ProjectId::from_proto);
2313 let leader_id = request
2314 .leader_id
2315 .ok_or_else(|| anyhow!("invalid leader id"))?
2316 .into();
2317 let follower_id = session.connection_id;
2318
2319 session
2320 .db()
2321 .await
2322 .check_room_participants(room_id, leader_id, session.connection_id)
2323 .await?;
2324
2325 let response_payload = session
2326 .peer
2327 .forward_request(session.connection_id, leader_id, request)
2328 .await?;
2329 response.send(response_payload)?;
2330
2331 if let Some(project_id) = project_id {
2332 let room = session
2333 .db()
2334 .await
2335 .follow(room_id, project_id, leader_id, follower_id)
2336 .await?;
2337 room_updated(&room, &session.peer);
2338 }
2339
2340 Ok(())
2341}
2342
2343/// Stop following another user in a call.
2344async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2345 let room_id = RoomId::from_proto(request.room_id);
2346 let project_id = request.project_id.map(ProjectId::from_proto);
2347 let leader_id = request
2348 .leader_id
2349 .ok_or_else(|| anyhow!("invalid leader id"))?
2350 .into();
2351 let follower_id = session.connection_id;
2352
2353 session
2354 .db()
2355 .await
2356 .check_room_participants(room_id, leader_id, session.connection_id)
2357 .await?;
2358
2359 session
2360 .peer
2361 .forward_send(session.connection_id, leader_id, request)?;
2362
2363 if let Some(project_id) = project_id {
2364 let room = session
2365 .db()
2366 .await
2367 .unfollow(room_id, project_id, leader_id, follower_id)
2368 .await?;
2369 room_updated(&room, &session.peer);
2370 }
2371
2372 Ok(())
2373}
2374
2375/// Notify everyone following you of your current location.
2376async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2377 let room_id = RoomId::from_proto(request.room_id);
2378 let database = session.db.lock().await;
2379
2380 let connection_ids = if let Some(project_id) = request.project_id {
2381 let project_id = ProjectId::from_proto(project_id);
2382 database
2383 .project_connection_ids(project_id, session.connection_id, true)
2384 .await?
2385 } else {
2386 database
2387 .room_connection_ids(room_id, session.connection_id)
2388 .await?
2389 };
2390
2391 // For now, don't send view update messages back to that view's current leader.
2392 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2393 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2394 _ => None,
2395 });
2396
2397 for connection_id in connection_ids.iter().cloned() {
2398 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2399 session
2400 .peer
2401 .forward_send(session.connection_id, connection_id, request.clone())?;
2402 }
2403 }
2404 Ok(())
2405}
2406
2407/// Get public data about users.
2408async fn get_users(
2409 request: proto::GetUsers,
2410 response: Response<proto::GetUsers>,
2411 session: Session,
2412) -> Result<()> {
2413 let user_ids = request
2414 .user_ids
2415 .into_iter()
2416 .map(UserId::from_proto)
2417 .collect();
2418 let users = session
2419 .db()
2420 .await
2421 .get_users_by_ids(user_ids)
2422 .await?
2423 .into_iter()
2424 .map(|user| proto::User {
2425 id: user.id.to_proto(),
2426 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2427 github_login: user.github_login,
2428 email: user.email_address,
2429 name: user.name,
2430 })
2431 .collect();
2432 response.send(proto::UsersResponse { users })?;
2433 Ok(())
2434}
2435
2436/// Search for users (to invite) buy Github login
2437async fn fuzzy_search_users(
2438 request: proto::FuzzySearchUsers,
2439 response: Response<proto::FuzzySearchUsers>,
2440 session: Session,
2441) -> Result<()> {
2442 let query = request.query;
2443 let users = match query.len() {
2444 0 => vec![],
2445 1 | 2 => session
2446 .db()
2447 .await
2448 .get_user_by_github_login(&query)
2449 .await?
2450 .into_iter()
2451 .collect(),
2452 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2453 };
2454 let users = users
2455 .into_iter()
2456 .filter(|user| user.id != session.user_id())
2457 .map(|user| proto::User {
2458 id: user.id.to_proto(),
2459 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2460 github_login: user.github_login,
2461 name: user.name,
2462 email: user.email_address,
2463 })
2464 .collect();
2465 response.send(proto::UsersResponse { users })?;
2466 Ok(())
2467}
2468
2469/// Send a contact request to another user.
2470async fn request_contact(
2471 request: proto::RequestContact,
2472 response: Response<proto::RequestContact>,
2473 session: Session,
2474) -> Result<()> {
2475 let requester_id = session.user_id();
2476 let responder_id = UserId::from_proto(request.responder_id);
2477 if requester_id == responder_id {
2478 return Err(anyhow!("cannot add yourself as a contact"))?;
2479 }
2480
2481 let notifications = session
2482 .db()
2483 .await
2484 .send_contact_request(requester_id, responder_id)
2485 .await?;
2486
2487 // Update outgoing contact requests of requester
2488 let mut update = proto::UpdateContacts::default();
2489 update.outgoing_requests.push(responder_id.to_proto());
2490 for connection_id in session
2491 .connection_pool()
2492 .await
2493 .user_connection_ids(requester_id)
2494 {
2495 session.peer.send(connection_id, update.clone())?;
2496 }
2497
2498 // Update incoming contact requests of responder
2499 let mut update = proto::UpdateContacts::default();
2500 update
2501 .incoming_requests
2502 .push(proto::IncomingContactRequest {
2503 requester_id: requester_id.to_proto(),
2504 });
2505 let connection_pool = session.connection_pool().await;
2506 for connection_id in connection_pool.user_connection_ids(responder_id) {
2507 session.peer.send(connection_id, update.clone())?;
2508 }
2509
2510 send_notifications(&connection_pool, &session.peer, notifications);
2511
2512 response.send(proto::Ack {})?;
2513 Ok(())
2514}
2515
2516/// Accept or decline a contact request
2517async fn respond_to_contact_request(
2518 request: proto::RespondToContactRequest,
2519 response: Response<proto::RespondToContactRequest>,
2520 session: Session,
2521) -> Result<()> {
2522 let responder_id = session.user_id();
2523 let requester_id = UserId::from_proto(request.requester_id);
2524 let db = session.db().await;
2525 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2526 db.dismiss_contact_notification(responder_id, requester_id)
2527 .await?;
2528 } else {
2529 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2530
2531 let notifications = db
2532 .respond_to_contact_request(responder_id, requester_id, accept)
2533 .await?;
2534 let requester_busy = db.is_user_busy(requester_id).await?;
2535 let responder_busy = db.is_user_busy(responder_id).await?;
2536
2537 let pool = session.connection_pool().await;
2538 // Update responder with new contact
2539 let mut update = proto::UpdateContacts::default();
2540 if accept {
2541 update
2542 .contacts
2543 .push(contact_for_user(requester_id, requester_busy, &pool));
2544 }
2545 update
2546 .remove_incoming_requests
2547 .push(requester_id.to_proto());
2548 for connection_id in pool.user_connection_ids(responder_id) {
2549 session.peer.send(connection_id, update.clone())?;
2550 }
2551
2552 // Update requester with new contact
2553 let mut update = proto::UpdateContacts::default();
2554 if accept {
2555 update
2556 .contacts
2557 .push(contact_for_user(responder_id, responder_busy, &pool));
2558 }
2559 update
2560 .remove_outgoing_requests
2561 .push(responder_id.to_proto());
2562
2563 for connection_id in pool.user_connection_ids(requester_id) {
2564 session.peer.send(connection_id, update.clone())?;
2565 }
2566
2567 send_notifications(&pool, &session.peer, notifications);
2568 }
2569
2570 response.send(proto::Ack {})?;
2571 Ok(())
2572}
2573
2574/// Remove a contact.
2575async fn remove_contact(
2576 request: proto::RemoveContact,
2577 response: Response<proto::RemoveContact>,
2578 session: Session,
2579) -> Result<()> {
2580 let requester_id = session.user_id();
2581 let responder_id = UserId::from_proto(request.user_id);
2582 let db = session.db().await;
2583 let (contact_accepted, deleted_notification_id) =
2584 db.remove_contact(requester_id, responder_id).await?;
2585
2586 let pool = session.connection_pool().await;
2587 // Update outgoing contact requests of requester
2588 let mut update = proto::UpdateContacts::default();
2589 if contact_accepted {
2590 update.remove_contacts.push(responder_id.to_proto());
2591 } else {
2592 update
2593 .remove_outgoing_requests
2594 .push(responder_id.to_proto());
2595 }
2596 for connection_id in pool.user_connection_ids(requester_id) {
2597 session.peer.send(connection_id, update.clone())?;
2598 }
2599
2600 // Update incoming contact requests of responder
2601 let mut update = proto::UpdateContacts::default();
2602 if contact_accepted {
2603 update.remove_contacts.push(requester_id.to_proto());
2604 } else {
2605 update
2606 .remove_incoming_requests
2607 .push(requester_id.to_proto());
2608 }
2609 for connection_id in pool.user_connection_ids(responder_id) {
2610 session.peer.send(connection_id, update.clone())?;
2611 if let Some(notification_id) = deleted_notification_id {
2612 session.peer.send(
2613 connection_id,
2614 proto::DeleteNotification {
2615 notification_id: notification_id.to_proto(),
2616 },
2617 )?;
2618 }
2619 }
2620
2621 response.send(proto::Ack {})?;
2622 Ok(())
2623}
2624
2625fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2626 version.0.minor() < 139
2627}
2628
2629async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2630 let plan = session.current_plan(&session.db().await).await?;
2631
2632 session
2633 .peer
2634 .send(
2635 session.connection_id,
2636 proto::UpdateUserPlan { plan: plan.into() },
2637 )
2638 .trace_err();
2639
2640 Ok(())
2641}
2642
2643async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2644 subscribe_user_to_channels(session.user_id(), &session).await?;
2645 Ok(())
2646}
2647
2648async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2649 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2650 let mut pool = session.connection_pool().await;
2651 for membership in &channels_for_user.channel_memberships {
2652 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2653 }
2654 session.peer.send(
2655 session.connection_id,
2656 build_update_user_channels(&channels_for_user),
2657 )?;
2658 session.peer.send(
2659 session.connection_id,
2660 build_channels_update(channels_for_user),
2661 )?;
2662 Ok(())
2663}
2664
2665/// Creates a new channel.
2666async fn create_channel(
2667 request: proto::CreateChannel,
2668 response: Response<proto::CreateChannel>,
2669 session: Session,
2670) -> Result<()> {
2671 let db = session.db().await;
2672
2673 let parent_id = request.parent_id.map(ChannelId::from_proto);
2674 let (channel, membership) = db
2675 .create_channel(&request.name, parent_id, session.user_id())
2676 .await?;
2677
2678 let root_id = channel.root_id();
2679 let channel = Channel::from_model(channel);
2680
2681 response.send(proto::CreateChannelResponse {
2682 channel: Some(channel.to_proto()),
2683 parent_id: request.parent_id,
2684 })?;
2685
2686 let mut connection_pool = session.connection_pool().await;
2687 if let Some(membership) = membership {
2688 connection_pool.subscribe_to_channel(
2689 membership.user_id,
2690 membership.channel_id,
2691 membership.role,
2692 );
2693 let update = proto::UpdateUserChannels {
2694 channel_memberships: vec![proto::ChannelMembership {
2695 channel_id: membership.channel_id.to_proto(),
2696 role: membership.role.into(),
2697 }],
2698 ..Default::default()
2699 };
2700 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2701 session.peer.send(connection_id, update.clone())?;
2702 }
2703 }
2704
2705 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2706 if !role.can_see_channel(channel.visibility) {
2707 continue;
2708 }
2709
2710 let update = proto::UpdateChannels {
2711 channels: vec![channel.to_proto()],
2712 ..Default::default()
2713 };
2714 session.peer.send(connection_id, update.clone())?;
2715 }
2716
2717 Ok(())
2718}
2719
2720/// Delete a channel
2721async fn delete_channel(
2722 request: proto::DeleteChannel,
2723 response: Response<proto::DeleteChannel>,
2724 session: Session,
2725) -> Result<()> {
2726 let db = session.db().await;
2727
2728 let channel_id = request.channel_id;
2729 let (root_channel, removed_channels) = db
2730 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2731 .await?;
2732 response.send(proto::Ack {})?;
2733
2734 // Notify members of removed channels
2735 let mut update = proto::UpdateChannels::default();
2736 update
2737 .delete_channels
2738 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2739
2740 let connection_pool = session.connection_pool().await;
2741 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2742 session.peer.send(connection_id, update.clone())?;
2743 }
2744
2745 Ok(())
2746}
2747
2748/// Invite someone to join a channel.
2749async fn invite_channel_member(
2750 request: proto::InviteChannelMember,
2751 response: Response<proto::InviteChannelMember>,
2752 session: Session,
2753) -> Result<()> {
2754 let db = session.db().await;
2755 let channel_id = ChannelId::from_proto(request.channel_id);
2756 let invitee_id = UserId::from_proto(request.user_id);
2757 let InviteMemberResult {
2758 channel,
2759 notifications,
2760 } = db
2761 .invite_channel_member(
2762 channel_id,
2763 invitee_id,
2764 session.user_id(),
2765 request.role().into(),
2766 )
2767 .await?;
2768
2769 let update = proto::UpdateChannels {
2770 channel_invitations: vec![channel.to_proto()],
2771 ..Default::default()
2772 };
2773
2774 let connection_pool = session.connection_pool().await;
2775 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2776 session.peer.send(connection_id, update.clone())?;
2777 }
2778
2779 send_notifications(&connection_pool, &session.peer, notifications);
2780
2781 response.send(proto::Ack {})?;
2782 Ok(())
2783}
2784
2785/// remove someone from a channel
2786async fn remove_channel_member(
2787 request: proto::RemoveChannelMember,
2788 response: Response<proto::RemoveChannelMember>,
2789 session: Session,
2790) -> Result<()> {
2791 let db = session.db().await;
2792 let channel_id = ChannelId::from_proto(request.channel_id);
2793 let member_id = UserId::from_proto(request.user_id);
2794
2795 let RemoveChannelMemberResult {
2796 membership_update,
2797 notification_id,
2798 } = db
2799 .remove_channel_member(channel_id, member_id, session.user_id())
2800 .await?;
2801
2802 let mut connection_pool = session.connection_pool().await;
2803 notify_membership_updated(
2804 &mut connection_pool,
2805 membership_update,
2806 member_id,
2807 &session.peer,
2808 );
2809 for connection_id in connection_pool.user_connection_ids(member_id) {
2810 if let Some(notification_id) = notification_id {
2811 session
2812 .peer
2813 .send(
2814 connection_id,
2815 proto::DeleteNotification {
2816 notification_id: notification_id.to_proto(),
2817 },
2818 )
2819 .trace_err();
2820 }
2821 }
2822
2823 response.send(proto::Ack {})?;
2824 Ok(())
2825}
2826
2827/// Toggle the channel between public and private.
2828/// Care is taken to maintain the invariant that public channels only descend from public channels,
2829/// (though members-only channels can appear at any point in the hierarchy).
2830async fn set_channel_visibility(
2831 request: proto::SetChannelVisibility,
2832 response: Response<proto::SetChannelVisibility>,
2833 session: Session,
2834) -> Result<()> {
2835 let db = session.db().await;
2836 let channel_id = ChannelId::from_proto(request.channel_id);
2837 let visibility = request.visibility().into();
2838
2839 let channel_model = db
2840 .set_channel_visibility(channel_id, visibility, session.user_id())
2841 .await?;
2842 let root_id = channel_model.root_id();
2843 let channel = Channel::from_model(channel_model);
2844
2845 let mut connection_pool = session.connection_pool().await;
2846 for (user_id, role) in connection_pool
2847 .channel_user_ids(root_id)
2848 .collect::<Vec<_>>()
2849 .into_iter()
2850 {
2851 let update = if role.can_see_channel(channel.visibility) {
2852 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2853 proto::UpdateChannels {
2854 channels: vec![channel.to_proto()],
2855 ..Default::default()
2856 }
2857 } else {
2858 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2859 proto::UpdateChannels {
2860 delete_channels: vec![channel.id.to_proto()],
2861 ..Default::default()
2862 }
2863 };
2864
2865 for connection_id in connection_pool.user_connection_ids(user_id) {
2866 session.peer.send(connection_id, update.clone())?;
2867 }
2868 }
2869
2870 response.send(proto::Ack {})?;
2871 Ok(())
2872}
2873
2874/// Alter the role for a user in the channel.
2875async fn set_channel_member_role(
2876 request: proto::SetChannelMemberRole,
2877 response: Response<proto::SetChannelMemberRole>,
2878 session: Session,
2879) -> Result<()> {
2880 let db = session.db().await;
2881 let channel_id = ChannelId::from_proto(request.channel_id);
2882 let member_id = UserId::from_proto(request.user_id);
2883 let result = db
2884 .set_channel_member_role(
2885 channel_id,
2886 session.user_id(),
2887 member_id,
2888 request.role().into(),
2889 )
2890 .await?;
2891
2892 match result {
2893 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2894 let mut connection_pool = session.connection_pool().await;
2895 notify_membership_updated(
2896 &mut connection_pool,
2897 membership_update,
2898 member_id,
2899 &session.peer,
2900 )
2901 }
2902 db::SetMemberRoleResult::InviteUpdated(channel) => {
2903 let update = proto::UpdateChannels {
2904 channel_invitations: vec![channel.to_proto()],
2905 ..Default::default()
2906 };
2907
2908 for connection_id in session
2909 .connection_pool()
2910 .await
2911 .user_connection_ids(member_id)
2912 {
2913 session.peer.send(connection_id, update.clone())?;
2914 }
2915 }
2916 }
2917
2918 response.send(proto::Ack {})?;
2919 Ok(())
2920}
2921
2922/// Change the name of a channel
2923async fn rename_channel(
2924 request: proto::RenameChannel,
2925 response: Response<proto::RenameChannel>,
2926 session: Session,
2927) -> Result<()> {
2928 let db = session.db().await;
2929 let channel_id = ChannelId::from_proto(request.channel_id);
2930 let channel_model = db
2931 .rename_channel(channel_id, session.user_id(), &request.name)
2932 .await?;
2933 let root_id = channel_model.root_id();
2934 let channel = Channel::from_model(channel_model);
2935
2936 response.send(proto::RenameChannelResponse {
2937 channel: Some(channel.to_proto()),
2938 })?;
2939
2940 let connection_pool = session.connection_pool().await;
2941 let update = proto::UpdateChannels {
2942 channels: vec![channel.to_proto()],
2943 ..Default::default()
2944 };
2945 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2946 if role.can_see_channel(channel.visibility) {
2947 session.peer.send(connection_id, update.clone())?;
2948 }
2949 }
2950
2951 Ok(())
2952}
2953
2954/// Move a channel to a new parent.
2955async fn move_channel(
2956 request: proto::MoveChannel,
2957 response: Response<proto::MoveChannel>,
2958 session: Session,
2959) -> Result<()> {
2960 let channel_id = ChannelId::from_proto(request.channel_id);
2961 let to = ChannelId::from_proto(request.to);
2962
2963 let (root_id, channels) = session
2964 .db()
2965 .await
2966 .move_channel(channel_id, to, session.user_id())
2967 .await?;
2968
2969 let connection_pool = session.connection_pool().await;
2970 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2971 let channels = channels
2972 .iter()
2973 .filter_map(|channel| {
2974 if role.can_see_channel(channel.visibility) {
2975 Some(channel.to_proto())
2976 } else {
2977 None
2978 }
2979 })
2980 .collect::<Vec<_>>();
2981 if channels.is_empty() {
2982 continue;
2983 }
2984
2985 let update = proto::UpdateChannels {
2986 channels,
2987 ..Default::default()
2988 };
2989
2990 session.peer.send(connection_id, update.clone())?;
2991 }
2992
2993 response.send(Ack {})?;
2994 Ok(())
2995}
2996
2997/// Get the list of channel members
2998async fn get_channel_members(
2999 request: proto::GetChannelMembers,
3000 response: Response<proto::GetChannelMembers>,
3001 session: Session,
3002) -> Result<()> {
3003 let db = session.db().await;
3004 let channel_id = ChannelId::from_proto(request.channel_id);
3005 let limit = if request.limit == 0 {
3006 u16::MAX as u64
3007 } else {
3008 request.limit
3009 };
3010 let (members, users) = db
3011 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3012 .await?;
3013 response.send(proto::GetChannelMembersResponse { members, users })?;
3014 Ok(())
3015}
3016
3017/// Accept or decline a channel invitation.
3018async fn respond_to_channel_invite(
3019 request: proto::RespondToChannelInvite,
3020 response: Response<proto::RespondToChannelInvite>,
3021 session: Session,
3022) -> Result<()> {
3023 let db = session.db().await;
3024 let channel_id = ChannelId::from_proto(request.channel_id);
3025 let RespondToChannelInvite {
3026 membership_update,
3027 notifications,
3028 } = db
3029 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3030 .await?;
3031
3032 let mut connection_pool = session.connection_pool().await;
3033 if let Some(membership_update) = membership_update {
3034 notify_membership_updated(
3035 &mut connection_pool,
3036 membership_update,
3037 session.user_id(),
3038 &session.peer,
3039 );
3040 } else {
3041 let update = proto::UpdateChannels {
3042 remove_channel_invitations: vec![channel_id.to_proto()],
3043 ..Default::default()
3044 };
3045
3046 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3047 session.peer.send(connection_id, update.clone())?;
3048 }
3049 };
3050
3051 send_notifications(&connection_pool, &session.peer, notifications);
3052
3053 response.send(proto::Ack {})?;
3054
3055 Ok(())
3056}
3057
3058/// Join the channels' room
3059async fn join_channel(
3060 request: proto::JoinChannel,
3061 response: Response<proto::JoinChannel>,
3062 session: Session,
3063) -> Result<()> {
3064 let channel_id = ChannelId::from_proto(request.channel_id);
3065 join_channel_internal(channel_id, Box::new(response), session).await
3066}
3067
/// A response handle that `join_channel_internal` can reply through.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so the same join logic can answer either
/// request with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets `join_channel_internal` reply to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response::send` for this request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets `join_channel_internal` reply to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response::send` for this request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3081
/// Join the session's user to the room of `channel_id` and reply through
/// `response`.
///
/// Steps:
/// 1. Clean up a stale room connection for the user, if the database reports one.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id, and (when a LiveKit client is
///    configured and token creation succeeds) LiveKit connection info.
/// 4. Propagate any membership change, then call `room_updated` and
///    `channel_updated`.
/// 5. Refresh the user's contacts.
///
/// Errors if the joined room comes back without a channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything that needs the database guard (and the first connection-pool
    // guard) happens inside this block; both are dropped when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our guard before leaving the stale room and re-acquire it
            // afterwards; presumably `leave_room_for_session` takes the database
            // handle itself — TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (`trace_err()?` short-circuits the closure to `None`).
        // Guests get a guest token and `can_publish: false`; every other role
        // gets a room token and `can_publish: true`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above have been dropped; take a fresh
    // connection-pool guard for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3177
3178/// Start editing the channel notes
3179async fn join_channel_buffer(
3180 request: proto::JoinChannelBuffer,
3181 response: Response<proto::JoinChannelBuffer>,
3182 session: Session,
3183) -> Result<()> {
3184 let db = session.db().await;
3185 let channel_id = ChannelId::from_proto(request.channel_id);
3186
3187 let open_response = db
3188 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3189 .await?;
3190
3191 let collaborators = open_response.collaborators.clone();
3192 response.send(open_response)?;
3193
3194 let update = UpdateChannelBufferCollaborators {
3195 channel_id: channel_id.to_proto(),
3196 collaborators: collaborators.clone(),
3197 };
3198 channel_buffer_updated(
3199 session.connection_id,
3200 collaborators
3201 .iter()
3202 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3203 &update,
3204 &session.peer,
3205 );
3206
3207 Ok(())
3208}
3209
3210/// Edit the channel notes
3211async fn update_channel_buffer(
3212 request: proto::UpdateChannelBuffer,
3213 session: Session,
3214) -> Result<()> {
3215 let db = session.db().await;
3216 let channel_id = ChannelId::from_proto(request.channel_id);
3217
3218 let (collaborators, epoch, version) = db
3219 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3220 .await?;
3221
3222 channel_buffer_updated(
3223 session.connection_id,
3224 collaborators.clone(),
3225 &proto::UpdateChannelBuffer {
3226 channel_id: channel_id.to_proto(),
3227 operations: request.operations,
3228 },
3229 &session.peer,
3230 );
3231
3232 let pool = &*session.connection_pool().await;
3233
3234 let non_collaborators =
3235 pool.channel_connection_ids(channel_id)
3236 .filter_map(|(connection_id, _)| {
3237 if collaborators.contains(&connection_id) {
3238 None
3239 } else {
3240 Some(connection_id)
3241 }
3242 });
3243
3244 broadcast(None, non_collaborators, |peer_id| {
3245 session.peer.send(
3246 peer_id,
3247 proto::UpdateChannels {
3248 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3249 channel_id: channel_id.to_proto(),
3250 epoch: epoch as u64,
3251 version: version.clone(),
3252 }],
3253 ..Default::default()
3254 },
3255 )
3256 });
3257
3258 Ok(())
3259}
3260
3261/// Rejoin the channel notes after a connection blip
3262async fn rejoin_channel_buffers(
3263 request: proto::RejoinChannelBuffers,
3264 response: Response<proto::RejoinChannelBuffers>,
3265 session: Session,
3266) -> Result<()> {
3267 let db = session.db().await;
3268 let buffers = db
3269 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3270 .await?;
3271
3272 for rejoined_buffer in &buffers {
3273 let collaborators_to_notify = rejoined_buffer
3274 .buffer
3275 .collaborators
3276 .iter()
3277 .filter_map(|c| Some(c.peer_id?.into()));
3278 channel_buffer_updated(
3279 session.connection_id,
3280 collaborators_to_notify,
3281 &proto::UpdateChannelBufferCollaborators {
3282 channel_id: rejoined_buffer.buffer.channel_id,
3283 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3284 },
3285 &session.peer,
3286 );
3287 }
3288
3289 response.send(proto::RejoinChannelBuffersResponse {
3290 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3291 })?;
3292
3293 Ok(())
3294}
3295
3296/// Stop editing the channel notes
3297async fn leave_channel_buffer(
3298 request: proto::LeaveChannelBuffer,
3299 response: Response<proto::LeaveChannelBuffer>,
3300 session: Session,
3301) -> Result<()> {
3302 let db = session.db().await;
3303 let channel_id = ChannelId::from_proto(request.channel_id);
3304
3305 let left_buffer = db
3306 .leave_channel_buffer(channel_id, session.connection_id)
3307 .await?;
3308
3309 response.send(Ack {})?;
3310
3311 channel_buffer_updated(
3312 session.connection_id,
3313 left_buffer.connections,
3314 &proto::UpdateChannelBufferCollaborators {
3315 channel_id: channel_id.to_proto(),
3316 collaborators: left_buffer.collaborators,
3317 },
3318 &session.peer,
3319 );
3320
3321 Ok(())
3322}
3323
3324fn channel_buffer_updated<T: EnvelopedMessage>(
3325 sender_id: ConnectionId,
3326 collaborators: impl IntoIterator<Item = ConnectionId>,
3327 message: &T,
3328 peer: &Peer,
3329) {
3330 broadcast(Some(sender_id), collaborators, |peer_id| {
3331 peer.send(peer_id, message.clone())
3332 });
3333}
3334
3335fn send_notifications(
3336 connection_pool: &ConnectionPool,
3337 peer: &Peer,
3338 notifications: db::NotificationBatch,
3339) {
3340 for (user_id, notification) in notifications {
3341 for connection_id in connection_pool.user_connection_ids(user_id) {
3342 if let Err(error) = peer.send(
3343 connection_id,
3344 proto::AddNotification {
3345 notification: Some(notification.clone()),
3346 },
3347 ) {
3348 tracing::error!(
3349 "failed to send notification to {:?} {}",
3350 connection_id,
3351 error
3352 );
3353 }
3354 }
3355 }
3356}
3357
3358/// Send a message to the channel
3359async fn send_channel_message(
3360 request: proto::SendChannelMessage,
3361 response: Response<proto::SendChannelMessage>,
3362 session: Session,
3363) -> Result<()> {
3364 // Validate the message body.
3365 let body = request.body.trim().to_string();
3366 if body.len() > MAX_MESSAGE_LEN {
3367 return Err(anyhow!("message is too long"))?;
3368 }
3369 if body.is_empty() {
3370 return Err(anyhow!("message can't be blank"))?;
3371 }
3372
3373 // TODO: adjust mentions if body is trimmed
3374
3375 let timestamp = OffsetDateTime::now_utc();
3376 let nonce = request
3377 .nonce
3378 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3379
3380 let channel_id = ChannelId::from_proto(request.channel_id);
3381 let CreatedChannelMessage {
3382 message_id,
3383 participant_connection_ids,
3384 notifications,
3385 } = session
3386 .db()
3387 .await
3388 .create_channel_message(
3389 channel_id,
3390 session.user_id(),
3391 &body,
3392 &request.mentions,
3393 timestamp,
3394 nonce.clone().into(),
3395 request.reply_to_message_id.map(MessageId::from_proto),
3396 )
3397 .await?;
3398
3399 let message = proto::ChannelMessage {
3400 sender_id: session.user_id().to_proto(),
3401 id: message_id.to_proto(),
3402 body,
3403 mentions: request.mentions,
3404 timestamp: timestamp.unix_timestamp() as u64,
3405 nonce: Some(nonce),
3406 reply_to_message_id: request.reply_to_message_id,
3407 edited_at: None,
3408 };
3409 broadcast(
3410 Some(session.connection_id),
3411 participant_connection_ids.clone(),
3412 |connection| {
3413 session.peer.send(
3414 connection,
3415 proto::ChannelMessageSent {
3416 channel_id: channel_id.to_proto(),
3417 message: Some(message.clone()),
3418 },
3419 )
3420 },
3421 );
3422 response.send(proto::SendChannelMessageResponse {
3423 message: Some(message),
3424 })?;
3425
3426 let pool = &*session.connection_pool().await;
3427 let non_participants =
3428 pool.channel_connection_ids(channel_id)
3429 .filter_map(|(connection_id, _)| {
3430 if participant_connection_ids.contains(&connection_id) {
3431 None
3432 } else {
3433 Some(connection_id)
3434 }
3435 });
3436 broadcast(None, non_participants, |peer_id| {
3437 session.peer.send(
3438 peer_id,
3439 proto::UpdateChannels {
3440 latest_channel_message_ids: vec![proto::ChannelMessageId {
3441 channel_id: channel_id.to_proto(),
3442 message_id: message_id.to_proto(),
3443 }],
3444 ..Default::default()
3445 },
3446 )
3447 });
3448 send_notifications(pool, &session.peer, notifications);
3449
3450 Ok(())
3451}
3452
3453/// Delete a channel message
3454async fn remove_channel_message(
3455 request: proto::RemoveChannelMessage,
3456 response: Response<proto::RemoveChannelMessage>,
3457 session: Session,
3458) -> Result<()> {
3459 let channel_id = ChannelId::from_proto(request.channel_id);
3460 let message_id = MessageId::from_proto(request.message_id);
3461 let (connection_ids, existing_notification_ids) = session
3462 .db()
3463 .await
3464 .remove_channel_message(channel_id, message_id, session.user_id())
3465 .await?;
3466
3467 broadcast(
3468 Some(session.connection_id),
3469 connection_ids,
3470 move |connection| {
3471 session.peer.send(connection, request.clone())?;
3472
3473 for notification_id in &existing_notification_ids {
3474 session.peer.send(
3475 connection,
3476 proto::DeleteNotification {
3477 notification_id: (*notification_id).to_proto(),
3478 },
3479 )?;
3480 }
3481
3482 Ok(())
3483 },
3484 );
3485 response.send(proto::Ack {})?;
3486 Ok(())
3487}
3488
3489async fn update_channel_message(
3490 request: proto::UpdateChannelMessage,
3491 response: Response<proto::UpdateChannelMessage>,
3492 session: Session,
3493) -> Result<()> {
3494 let channel_id = ChannelId::from_proto(request.channel_id);
3495 let message_id = MessageId::from_proto(request.message_id);
3496 let updated_at = OffsetDateTime::now_utc();
3497 let UpdatedChannelMessage {
3498 message_id,
3499 participant_connection_ids,
3500 notifications,
3501 reply_to_message_id,
3502 timestamp,
3503 deleted_mention_notification_ids,
3504 updated_mention_notifications,
3505 } = session
3506 .db()
3507 .await
3508 .update_channel_message(
3509 channel_id,
3510 message_id,
3511 session.user_id(),
3512 request.body.as_str(),
3513 &request.mentions,
3514 updated_at,
3515 )
3516 .await?;
3517
3518 let nonce = request
3519 .nonce
3520 .clone()
3521 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3522
3523 let message = proto::ChannelMessage {
3524 sender_id: session.user_id().to_proto(),
3525 id: message_id.to_proto(),
3526 body: request.body.clone(),
3527 mentions: request.mentions.clone(),
3528 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3529 nonce: Some(nonce),
3530 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3531 edited_at: Some(updated_at.unix_timestamp() as u64),
3532 };
3533
3534 response.send(proto::Ack {})?;
3535
3536 let pool = &*session.connection_pool().await;
3537 broadcast(
3538 Some(session.connection_id),
3539 participant_connection_ids,
3540 |connection| {
3541 session.peer.send(
3542 connection,
3543 proto::ChannelMessageUpdate {
3544 channel_id: channel_id.to_proto(),
3545 message: Some(message.clone()),
3546 },
3547 )?;
3548
3549 for notification_id in &deleted_mention_notification_ids {
3550 session.peer.send(
3551 connection,
3552 proto::DeleteNotification {
3553 notification_id: (*notification_id).to_proto(),
3554 },
3555 )?;
3556 }
3557
3558 for notification in &updated_mention_notifications {
3559 session.peer.send(
3560 connection,
3561 proto::UpdateNotification {
3562 notification: Some(notification.clone()),
3563 },
3564 )?;
3565 }
3566
3567 Ok(())
3568 },
3569 );
3570
3571 send_notifications(pool, &session.peer, notifications);
3572
3573 Ok(())
3574}
3575
3576/// Mark a channel message as read
3577async fn acknowledge_channel_message(
3578 request: proto::AckChannelMessage,
3579 session: Session,
3580) -> Result<()> {
3581 let channel_id = ChannelId::from_proto(request.channel_id);
3582 let message_id = MessageId::from_proto(request.message_id);
3583 let notifications = session
3584 .db()
3585 .await
3586 .observe_channel_message(channel_id, session.user_id(), message_id)
3587 .await?;
3588 send_notifications(
3589 &*session.connection_pool().await,
3590 &session.peer,
3591 notifications,
3592 );
3593 Ok(())
3594}
3595
3596/// Mark a buffer version as synced
3597async fn acknowledge_buffer_version(
3598 request: proto::AckBufferOperation,
3599 session: Session,
3600) -> Result<()> {
3601 let buffer_id = BufferId::from_proto(request.buffer_id);
3602 session
3603 .db()
3604 .await
3605 .observe_buffer_version(
3606 buffer_id,
3607 session.user_id(),
3608 request.epoch as i32,
3609 &request.version,
3610 )
3611 .await?;
3612 Ok(())
3613}
3614
3615async fn count_language_model_tokens(
3616 request: proto::CountLanguageModelTokens,
3617 response: Response<proto::CountLanguageModelTokens>,
3618 session: Session,
3619 config: &Config,
3620) -> Result<()> {
3621 authorize_access_to_legacy_llm_endpoints(&session).await?;
3622
3623 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3624 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3625 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3626 };
3627
3628 session
3629 .app_state
3630 .rate_limiter
3631 .check(&*rate_limit, session.user_id())
3632 .await?;
3633
3634 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3635 Some(proto::LanguageModelProvider::Google) => {
3636 let api_key = config
3637 .google_ai_api_key
3638 .as_ref()
3639 .context("no Google AI API key configured on the server")?;
3640 google_ai::count_tokens(
3641 session.http_client.as_ref(),
3642 google_ai::API_URL,
3643 api_key,
3644 serde_json::from_str(&request.request)?,
3645 )
3646 .await?
3647 }
3648 _ => return Err(anyhow!("unsupported provider"))?,
3649 };
3650
3651 response.send(proto::CountLanguageModelTokensResponse {
3652 token_count: result.total_tokens as u32,
3653 })?;
3654
3655 Ok(())
3656}
3657
3658struct ZedProCountLanguageModelTokensRateLimit;
3659
3660impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3661 fn capacity(&self) -> usize {
3662 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3663 .ok()
3664 .and_then(|v| v.parse().ok())
3665 .unwrap_or(600) // Picked arbitrarily
3666 }
3667
3668 fn refill_duration(&self) -> chrono::Duration {
3669 chrono::Duration::hours(1)
3670 }
3671
3672 fn db_name(&self) -> &'static str {
3673 "zed-pro:count-language-model-tokens"
3674 }
3675}
3676
3677struct FreeCountLanguageModelTokensRateLimit;
3678
3679impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3680 fn capacity(&self) -> usize {
3681 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3682 .ok()
3683 .and_then(|v| v.parse().ok())
3684 .unwrap_or(600 / 10) // Picked arbitrarily
3685 }
3686
3687 fn refill_duration(&self) -> chrono::Duration {
3688 chrono::Duration::hours(1)
3689 }
3690
3691 fn db_name(&self) -> &'static str {
3692 "free:count-language-model-tokens"
3693 }
3694}
3695
3696struct ZedProComputeEmbeddingsRateLimit;
3697
3698impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3699 fn capacity(&self) -> usize {
3700 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3701 .ok()
3702 .and_then(|v| v.parse().ok())
3703 .unwrap_or(5000) // Picked arbitrarily
3704 }
3705
3706 fn refill_duration(&self) -> chrono::Duration {
3707 chrono::Duration::hours(1)
3708 }
3709
3710 fn db_name(&self) -> &'static str {
3711 "zed-pro:compute-embeddings"
3712 }
3713}
3714
3715struct FreeComputeEmbeddingsRateLimit;
3716
3717impl RateLimit for FreeComputeEmbeddingsRateLimit {
3718 fn capacity(&self) -> usize {
3719 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3720 .ok()
3721 .and_then(|v| v.parse().ok())
3722 .unwrap_or(5000 / 10) // Picked arbitrarily
3723 }
3724
3725 fn refill_duration(&self) -> chrono::Duration {
3726 chrono::Duration::hours(1)
3727 }
3728
3729 fn db_name(&self) -> &'static str {
3730 "free:compute-embeddings"
3731 }
3732}
3733
3734async fn compute_embeddings(
3735 request: proto::ComputeEmbeddings,
3736 response: Response<proto::ComputeEmbeddings>,
3737 session: Session,
3738 api_key: Option<Arc<str>>,
3739) -> Result<()> {
3740 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3741 authorize_access_to_legacy_llm_endpoints(&session).await?;
3742
3743 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3744 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3745 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3746 };
3747
3748 session
3749 .app_state
3750 .rate_limiter
3751 .check(&*rate_limit, session.user_id())
3752 .await?;
3753
3754 let embeddings = match request.model.as_str() {
3755 "openai/text-embedding-3-small" => {
3756 open_ai::embed(
3757 session.http_client.as_ref(),
3758 OPEN_AI_API_URL,
3759 &api_key,
3760 OpenAiEmbeddingModel::TextEmbedding3Small,
3761 request.texts.iter().map(|text| text.as_str()),
3762 )
3763 .await?
3764 }
3765 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3766 };
3767
3768 let embeddings = request
3769 .texts
3770 .iter()
3771 .map(|text| {
3772 let mut hasher = sha2::Sha256::new();
3773 hasher.update(text.as_bytes());
3774 let result = hasher.finalize();
3775 result.to_vec()
3776 })
3777 .zip(
3778 embeddings
3779 .data
3780 .into_iter()
3781 .map(|embedding| embedding.embedding),
3782 )
3783 .collect::<HashMap<_, _>>();
3784
3785 let db = session.db().await;
3786 db.save_embeddings(&request.model, &embeddings)
3787 .await
3788 .context("failed to save embeddings")
3789 .trace_err();
3790
3791 response.send(proto::ComputeEmbeddingsResponse {
3792 embeddings: embeddings
3793 .into_iter()
3794 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3795 .collect(),
3796 })?;
3797 Ok(())
3798}
3799
3800async fn get_cached_embeddings(
3801 request: proto::GetCachedEmbeddings,
3802 response: Response<proto::GetCachedEmbeddings>,
3803 session: Session,
3804) -> Result<()> {
3805 authorize_access_to_legacy_llm_endpoints(&session).await?;
3806
3807 let db = session.db().await;
3808 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3809
3810 response.send(proto::GetCachedEmbeddingsResponse {
3811 embeddings: embeddings
3812 .into_iter()
3813 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3814 .collect(),
3815 })?;
3816 Ok(())
3817}
3818
3819/// This is leftover from before the LLM service.
3820///
3821/// The endpoints protected by this check will be moved there eventually.
3822async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3823 if session.is_staff() {
3824 Ok(())
3825 } else {
3826 Err(anyhow!("permission denied"))?
3827 }
3828}
3829
3830/// Get a Supermaven API key for the user
3831async fn get_supermaven_api_key(
3832 _request: proto::GetSupermavenApiKey,
3833 response: Response<proto::GetSupermavenApiKey>,
3834 session: Session,
3835) -> Result<()> {
3836 let user_id: String = session.user_id().to_string();
3837 if !session.is_staff() {
3838 return Err(anyhow!("supermaven not enabled for this account"))?;
3839 }
3840
3841 let email = session
3842 .email()
3843 .ok_or_else(|| anyhow!("user must have an email"))?;
3844
3845 let supermaven_admin_api = session
3846 .supermaven_client
3847 .as_ref()
3848 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3849
3850 let result = supermaven_admin_api
3851 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3852 .await?;
3853
3854 response.send(proto::GetSupermavenApiKeyResponse {
3855 api_key: result.api_key,
3856 })?;
3857
3858 Ok(())
3859}
3860
3861/// Start receiving chat updates for a channel
3862async fn join_channel_chat(
3863 request: proto::JoinChannelChat,
3864 response: Response<proto::JoinChannelChat>,
3865 session: Session,
3866) -> Result<()> {
3867 let channel_id = ChannelId::from_proto(request.channel_id);
3868
3869 let db = session.db().await;
3870 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3871 .await?;
3872 let messages = db
3873 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3874 .await?;
3875 response.send(proto::JoinChannelChatResponse {
3876 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3877 messages,
3878 })?;
3879 Ok(())
3880}
3881
3882/// Stop receiving chat updates for a channel
3883async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3884 let channel_id = ChannelId::from_proto(request.channel_id);
3885 session
3886 .db()
3887 .await
3888 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3889 .await?;
3890 Ok(())
3891}
3892
3893/// Retrieve the chat history for a channel
3894async fn get_channel_messages(
3895 request: proto::GetChannelMessages,
3896 response: Response<proto::GetChannelMessages>,
3897 session: Session,
3898) -> Result<()> {
3899 let channel_id = ChannelId::from_proto(request.channel_id);
3900 let messages = session
3901 .db()
3902 .await
3903 .get_channel_messages(
3904 channel_id,
3905 session.user_id(),
3906 MESSAGE_COUNT_PER_PAGE,
3907 Some(MessageId::from_proto(request.before_message_id)),
3908 )
3909 .await?;
3910 response.send(proto::GetChannelMessagesResponse {
3911 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3912 messages,
3913 })?;
3914 Ok(())
3915}
3916
3917/// Retrieve specific chat messages
3918async fn get_channel_messages_by_id(
3919 request: proto::GetChannelMessagesById,
3920 response: Response<proto::GetChannelMessagesById>,
3921 session: Session,
3922) -> Result<()> {
3923 let message_ids = request
3924 .message_ids
3925 .iter()
3926 .map(|id| MessageId::from_proto(*id))
3927 .collect::<Vec<_>>();
3928 let messages = session
3929 .db()
3930 .await
3931 .get_channel_messages_by_id(session.user_id(), &message_ids)
3932 .await?;
3933 response.send(proto::GetChannelMessagesResponse {
3934 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3935 messages,
3936 })?;
3937 Ok(())
3938}
3939
3940/// Retrieve the current users notifications
3941async fn get_notifications(
3942 request: proto::GetNotifications,
3943 response: Response<proto::GetNotifications>,
3944 session: Session,
3945) -> Result<()> {
3946 let notifications = session
3947 .db()
3948 .await
3949 .get_notifications(
3950 session.user_id(),
3951 NOTIFICATION_COUNT_PER_PAGE,
3952 request.before_id.map(db::NotificationId::from_proto),
3953 )
3954 .await?;
3955 response.send(proto::GetNotificationsResponse {
3956 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3957 notifications,
3958 })?;
3959 Ok(())
3960}
3961
3962/// Mark notifications as read
3963async fn mark_notification_as_read(
3964 request: proto::MarkNotificationRead,
3965 response: Response<proto::MarkNotificationRead>,
3966 session: Session,
3967) -> Result<()> {
3968 let database = &session.db().await;
3969 let notifications = database
3970 .mark_notification_as_read_by_id(
3971 session.user_id(),
3972 NotificationId::from_proto(request.notification_id),
3973 )
3974 .await?;
3975 send_notifications(
3976 &*session.connection_pool().await,
3977 &session.peer,
3978 notifications,
3979 );
3980 response.send(proto::Ack {})?;
3981 Ok(())
3982}
3983
3984/// Get the current users information
3985async fn get_private_user_info(
3986 _request: proto::GetPrivateUserInfo,
3987 response: Response<proto::GetPrivateUserInfo>,
3988 session: Session,
3989) -> Result<()> {
3990 let db = session.db().await;
3991
3992 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3993 let user = db
3994 .get_user_by_id(session.user_id())
3995 .await?
3996 .ok_or_else(|| anyhow!("user not found"))?;
3997 let flags = db.get_user_flags(session.user_id()).await?;
3998
3999 response.send(proto::GetPrivateUserInfoResponse {
4000 metrics_id,
4001 staff: user.admin,
4002 flags,
4003 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4004 })?;
4005 Ok(())
4006}
4007
4008/// Accept the terms of service (tos) on behalf of the current user
4009async fn accept_terms_of_service(
4010 _request: proto::AcceptTermsOfService,
4011 response: Response<proto::AcceptTermsOfService>,
4012 session: Session,
4013) -> Result<()> {
4014 let db = session.db().await;
4015
4016 let accepted_tos_at = Utc::now();
4017 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4018 .await?;
4019
4020 response.send(proto::AcceptTermsOfServiceResponse {
4021 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4022 })?;
4023 Ok(())
4024}
4025
4026/// The minimum account age an account must have in order to use the LLM service.
4027const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4028
4029async fn get_llm_api_token(
4030 _request: proto::GetLlmToken,
4031 response: Response<proto::GetLlmToken>,
4032 session: Session,
4033) -> Result<()> {
4034 let db = session.db().await;
4035
4036 let flags = db.get_user_flags(session.user_id()).await?;
4037 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4038 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4039 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4040
4041 if !session.is_staff() && !has_language_models_feature_flag {
4042 Err(anyhow!("permission denied"))?
4043 }
4044
4045 let user_id = session.user_id();
4046 let user = db
4047 .get_user_by_id(user_id)
4048 .await?
4049 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4050
4051 if user.accepted_tos_at.is_none() {
4052 Err(anyhow!("terms of service not accepted"))?
4053 }
4054
4055 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4056
4057 let bypass_account_age_check =
4058 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4059 if !bypass_account_age_check {
4060 let mut account_created_at = user.created_at;
4061 if let Some(github_created_at) = user.github_user_created_at {
4062 account_created_at = account_created_at.min(github_created_at);
4063 }
4064 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4065 Err(anyhow!("account too young"))?
4066 }
4067 }
4068
4069 let billing_preferences = db.get_billing_preferences(user.id).await?;
4070
4071 let token = LlmTokenClaims::create(
4072 &user,
4073 session.is_staff(),
4074 billing_preferences,
4075 has_llm_closed_beta_feature_flag,
4076 has_predict_edits_feature_flag,
4077 has_llm_subscription,
4078 session.current_plan(&db).await?,
4079 session.system_id.clone(),
4080 &session.app_state.config,
4081 )?;
4082 response.send(proto::GetLlmTokenResponse { token })?;
4083 Ok(())
4084}
4085
4086fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4087 let message = match message {
4088 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4089 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4090 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4091 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4092 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4093 code: frame.code.into(),
4094 reason: frame.reason,
4095 })),
4096 // We should never receive a frame while reading the message, according
4097 // to the `tungstenite` maintainers:
4098 //
4099 // > It cannot occur when you read messages from the WebSocket, but it
4100 // > can be used when you want to send the raw frames (e.g. you want to
4101 // > send the frames to the WebSocket without composing the full message first).
4102 // >
4103 // > — https://github.com/snapview/tungstenite-rs/issues/268
4104 TungsteniteMessage::Frame(_) => {
4105 bail!("received an unexpected frame while reading the message")
4106 }
4107 };
4108
4109 Ok(message)
4110}
4111
4112fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4113 match message {
4114 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4115 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4116 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4117 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4118 AxumMessage::Close(frame) => {
4119 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4120 code: frame.code.into(),
4121 reason: frame.reason,
4122 }))
4123 }
4124 }
4125}
4126
4127fn notify_membership_updated(
4128 connection_pool: &mut ConnectionPool,
4129 result: MembershipUpdated,
4130 user_id: UserId,
4131 peer: &Peer,
4132) {
4133 for membership in &result.new_channels.channel_memberships {
4134 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4135 }
4136 for channel_id in &result.removed_channels {
4137 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4138 }
4139
4140 let user_channels_update = proto::UpdateUserChannels {
4141 channel_memberships: result
4142 .new_channels
4143 .channel_memberships
4144 .iter()
4145 .map(|cm| proto::ChannelMembership {
4146 channel_id: cm.channel_id.to_proto(),
4147 role: cm.role.into(),
4148 })
4149 .collect(),
4150 ..Default::default()
4151 };
4152
4153 let mut update = build_channels_update(result.new_channels);
4154 update.delete_channels = result
4155 .removed_channels
4156 .into_iter()
4157 .map(|id| id.to_proto())
4158 .collect();
4159 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4160
4161 for connection_id in connection_pool.user_connection_ids(user_id) {
4162 peer.send(connection_id, user_channels_update.clone())
4163 .trace_err();
4164 peer.send(connection_id, update.clone()).trace_err();
4165 }
4166}
4167
4168fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4169 proto::UpdateUserChannels {
4170 channel_memberships: channels
4171 .channel_memberships
4172 .iter()
4173 .map(|m| proto::ChannelMembership {
4174 channel_id: m.channel_id.to_proto(),
4175 role: m.role.into(),
4176 })
4177 .collect(),
4178 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4179 observed_channel_message_id: channels.observed_channel_messages.clone(),
4180 }
4181}
4182
4183fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4184 let mut update = proto::UpdateChannels::default();
4185
4186 for channel in channels.channels {
4187 update.channels.push(channel.to_proto());
4188 }
4189
4190 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4191 update.latest_channel_message_ids = channels.latest_channel_messages;
4192
4193 for (channel_id, participants) in channels.channel_participants {
4194 update
4195 .channel_participants
4196 .push(proto::ChannelParticipants {
4197 channel_id: channel_id.to_proto(),
4198 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4199 });
4200 }
4201
4202 for channel in channels.invited_channels {
4203 update.channel_invitations.push(channel.to_proto());
4204 }
4205
4206 update
4207}
4208
4209fn build_initial_contacts_update(
4210 contacts: Vec<db::Contact>,
4211 pool: &ConnectionPool,
4212) -> proto::UpdateContacts {
4213 let mut update = proto::UpdateContacts::default();
4214
4215 for contact in contacts {
4216 match contact {
4217 db::Contact::Accepted { user_id, busy } => {
4218 update.contacts.push(contact_for_user(user_id, busy, pool));
4219 }
4220 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4221 db::Contact::Incoming { user_id } => {
4222 update
4223 .incoming_requests
4224 .push(proto::IncomingContactRequest {
4225 requester_id: user_id.to_proto(),
4226 })
4227 }
4228 }
4229 }
4230
4231 update
4232}
4233
4234fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4235 proto::Contact {
4236 user_id: user_id.to_proto(),
4237 online: pool.is_user_online(user_id),
4238 busy,
4239 }
4240}
4241
4242fn room_updated(room: &proto::Room, peer: &Peer) {
4243 broadcast(
4244 None,
4245 room.participants
4246 .iter()
4247 .filter_map(|participant| Some(participant.peer_id?.into())),
4248 |peer_id| {
4249 peer.send(
4250 peer_id,
4251 proto::RoomUpdated {
4252 room: Some(room.clone()),
4253 },
4254 )
4255 },
4256 );
4257}
4258
4259fn channel_updated(
4260 channel: &db::channel::Model,
4261 room: &proto::Room,
4262 peer: &Peer,
4263 pool: &ConnectionPool,
4264) {
4265 let participants = room
4266 .participants
4267 .iter()
4268 .map(|p| p.user_id)
4269 .collect::<Vec<_>>();
4270
4271 broadcast(
4272 None,
4273 pool.channel_connection_ids(channel.root_id())
4274 .filter_map(|(channel_id, role)| {
4275 role.can_see_channel(channel.visibility)
4276 .then_some(channel_id)
4277 }),
4278 |peer_id| {
4279 peer.send(
4280 peer_id,
4281 proto::UpdateChannels {
4282 channel_participants: vec![proto::ChannelParticipants {
4283 channel_id: channel.id.to_proto(),
4284 participant_user_ids: participants.clone(),
4285 }],
4286 ..Default::default()
4287 },
4288 )
4289 },
4290 );
4291}
4292
4293async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4294 let db = session.db().await;
4295
4296 let contacts = db.get_contacts(user_id).await?;
4297 let busy = db.is_user_busy(user_id).await?;
4298
4299 let pool = session.connection_pool().await;
4300 let updated_contact = contact_for_user(user_id, busy, &pool);
4301 for contact in contacts {
4302 if let db::Contact::Accepted {
4303 user_id: contact_user_id,
4304 ..
4305 } = contact
4306 {
4307 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4308 session
4309 .peer
4310 .send(
4311 contact_conn_id,
4312 proto::UpdateContacts {
4313 contacts: vec![updated_contact.clone()],
4314 remove_contacts: Default::default(),
4315 incoming_requests: Default::default(),
4316 remove_incoming_requests: Default::default(),
4317 outgoing_requests: Default::default(),
4318 remove_outgoing_requests: Default::default(),
4319 },
4320 )
4321 .trace_err();
4322 }
4323 }
4324 }
4325 Ok(())
4326}
4327
/// Removes `connection_id` from the room the database says it is in, then
/// notifies everyone affected.
///
/// Returns `Ok(())` without doing anything further when `leave_room` reports
/// that the connection was not in a room. Otherwise it does the following, in order:
/// - notifies the connections of each left project (via `project_left`),
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's participant list,
/// - sends `CallCanceled` to every connection of each user whose pending call
///   was canceled,
/// - refreshes the contact entries of this session's user and of those canceled users,
/// - removes this user from the LiveKit room, and deletes that room when
///   `left_room.deleted` is set. LiveKit errors are logged, not returned.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. An error from
/// `update_user_contacts` returns early, which skips the LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Assigned only in the `if let` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): under Rust 2021 temporary-lifetime rules, the guard returned by
    // `session.db().await` lives until the end of this `if let` body. Confirm that
    // holding it that long is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so `room.livekit_room` is left at its
        // `Default` value from here on, including in the `room_updated` broadcast below.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before awaiting
    // `update_user_contacts`, which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4401
4402async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4403 let left_channel_buffers = session
4404 .db()
4405 .await
4406 .leave_channel_buffers(session.connection_id)
4407 .await?;
4408
4409 for left_buffer in left_channel_buffers {
4410 channel_buffer_updated(
4411 session.connection_id,
4412 left_buffer.connections,
4413 &proto::UpdateChannelBufferCollaborators {
4414 channel_id: left_buffer.channel_id.to_proto(),
4415 collaborators: left_buffer.collaborators,
4416 },
4417 &session.peer,
4418 );
4419 }
4420
4421 Ok(())
4422}
4423
4424fn project_left(project: &db::LeftProject, session: &Session) {
4425 for connection_id in &project.connection_ids {
4426 if project.should_unshare {
4427 session
4428 .peer
4429 .send(
4430 *connection_id,
4431 proto::UnshareProject {
4432 project_id: project.id.to_proto(),
4433 },
4434 )
4435 .trace_err();
4436 } else {
4437 session
4438 .peer
4439 .send(
4440 *connection_id,
4441 proto::RemoveProjectCollaborator {
4442 project_id: project.id.to_proto(),
4443 peer_id: Some(session.connection_id.into()),
4444 },
4445 )
4446 .trace_err();
4447 }
4448 }
4449}
4450
4451pub trait ResultExt {
4452 type Ok;
4453
4454 fn trace_err(self) -> Option<Self::Ok>;
4455}
4456
4457impl<T, E> ResultExt for Result<T, E>
4458where
4459 E: std::fmt::Debug,
4460{
4461 type Ok = T;
4462
4463 #[track_caller]
4464 fn trace_err(self) -> Option<T> {
4465 match self {
4466 Ok(value) => Some(value),
4467 Err(error) => {
4468 tracing::error!("{:?}", error);
4469 None
4470 }
4471 }
4472 }
4473}