1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection grace period (30 seconds).
/// NOTE(review): the code that consumes this constant is outside this part of the file.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a newly started server cleans up resources left behind by earlier
/// servers. Kubernetes gives terminated pods 10s to shut down gracefully; waiting a
/// little longer means those pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message queries — consumers are outside this chunk.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message — confirm against its use sites.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries — consumers are outside this chunk.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// State for a single client connection. A clone is passed to every message
/// handler dispatched for that connection.
#[derive(Clone)]
struct Session {
    /// The authenticated user, or a user impersonated by an admin.
    principal: Principal,
    /// This connection's id as assigned by the server's `Peer`.
    connection_id: ConnectionId,
    /// Database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The RPC peer used to send messages to clients.
    peer: Arc<Peer>,
    /// The server-wide connection pool (a clone of `Server::connection_pool`);
    /// acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read in this part of the file; the allow keeps the field from warning.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Optional system id supplied when the connection was opened.
    system_id: Option<String>,
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration RPC server: owns the RPC peer, the connection pool, and the
/// table of per-message-type handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Shared with every `Session` created by `handle_connection`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; each connection task subscribes to it.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (and `!Sync`),
/// presumably so it cannot be held across an `.await` in a future that must be
/// `Send`, since the underlying `parking_lot` lock is synchronous.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool.
///
/// The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it points to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
341 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
342 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
344 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
345 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
346 .add_message_handler(create_buffer_for_peer)
347 .add_request_handler(update_buffer)
348 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
349 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
353 .add_request_handler(get_users)
354 .add_request_handler(fuzzy_search_users)
355 .add_request_handler(request_contact)
356 .add_request_handler(remove_contact)
357 .add_request_handler(respond_to_contact_request)
358 .add_message_handler(subscribe_to_channels)
359 .add_request_handler(create_channel)
360 .add_request_handler(delete_channel)
361 .add_request_handler(invite_channel_member)
362 .add_request_handler(remove_channel_member)
363 .add_request_handler(set_channel_member_role)
364 .add_request_handler(set_channel_visibility)
365 .add_request_handler(rename_channel)
366 .add_request_handler(join_channel_buffer)
367 .add_request_handler(leave_channel_buffer)
368 .add_message_handler(update_channel_buffer)
369 .add_request_handler(rejoin_channel_buffers)
370 .add_request_handler(get_channel_members)
371 .add_request_handler(respond_to_channel_invite)
372 .add_request_handler(join_channel)
373 .add_request_handler(join_channel_chat)
374 .add_message_handler(leave_channel_chat)
375 .add_request_handler(send_channel_message)
376 .add_request_handler(remove_channel_message)
377 .add_request_handler(update_channel_message)
378 .add_request_handler(get_channel_messages)
379 .add_request_handler(get_channel_messages_by_id)
380 .add_request_handler(get_notifications)
381 .add_request_handler(mark_notification_as_read)
382 .add_request_handler(move_channel)
383 .add_request_handler(follow)
384 .add_message_handler(unfollow)
385 .add_message_handler(update_followers)
386 .add_request_handler(get_private_user_info)
387 .add_request_handler(get_llm_api_token)
388 .add_request_handler(accept_terms_of_service)
389 .add_message_handler(acknowledge_channel_message)
390 .add_message_handler(acknowledge_buffer_version)
391 .add_request_handler(get_supermaven_api_key)
392 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
393 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
395 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
396 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
398 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
399 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
400 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
401 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
402 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
403 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
404 .add_message_handler(update_context)
405 .add_request_handler({
406 let app_state = app_state.clone();
407 move |request, response, session| {
408 let app_state = app_state.clone();
409 async move {
410 count_language_model_tokens(request, response, session, &app_state.config)
411 .await
412 }
413 }
414 })
415 .add_request_handler(get_cached_embeddings)
416 .add_request_handler({
417 let app_state = app_state.clone();
418 move |request, response, session| {
419 compute_embeddings(
420 request,
421 response,
422 session,
423 app_state.config.openai_api_key.clone(),
424 )
425 }
426 });
427
428 Arc::new(server)
429 }
430
431 pub async fn start(&self) -> Result<()> {
432 let server_id = *self.id.lock();
433 let app_state = self.app_state.clone();
434 let peer = self.peer.clone();
435 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
436 let pool = self.connection_pool.clone();
437 let livekit_client = self.app_state.livekit_client.clone();
438
439 let span = info_span!("start server");
440 self.app_state.executor.spawn_detached(
441 async move {
442 tracing::info!("waiting for cleanup timeout");
443 timeout.await;
444 tracing::info!("cleanup timeout expired, retrieving stale rooms");
445 if let Some((room_ids, channel_ids)) = app_state
446 .db
447 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
448 .await
449 .trace_err()
450 {
451 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
452 tracing::info!(
453 stale_channel_buffer_count = channel_ids.len(),
454 "retrieved stale channel buffers"
455 );
456
457 for channel_id in channel_ids {
458 if let Some(refreshed_channel_buffer) = app_state
459 .db
460 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
461 .await
462 .trace_err()
463 {
464 for connection_id in refreshed_channel_buffer.connection_ids {
465 peer.send(
466 connection_id,
467 proto::UpdateChannelBufferCollaborators {
468 channel_id: channel_id.to_proto(),
469 collaborators: refreshed_channel_buffer
470 .collaborators
471 .clone(),
472 },
473 )
474 .trace_err();
475 }
476 }
477 }
478
479 for room_id in room_ids {
480 let mut contacts_to_update = HashSet::default();
481 let mut canceled_calls_to_user_ids = Vec::new();
482 let mut livekit_room = String::new();
483 let mut delete_livekit_room = false;
484
485 if let Some(mut refreshed_room) = app_state
486 .db
487 .clear_stale_room_participants(room_id, server_id)
488 .await
489 .trace_err()
490 {
491 tracing::info!(
492 room_id = room_id.0,
493 new_participant_count = refreshed_room.room.participants.len(),
494 "refreshed room"
495 );
496 room_updated(&refreshed_room.room, &peer);
497 if let Some(channel) = refreshed_room.channel.as_ref() {
498 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
499 }
500 contacts_to_update
501 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
502 contacts_to_update
503 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
504 canceled_calls_to_user_ids =
505 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
506 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
507 delete_livekit_room = refreshed_room.room.participants.is_empty();
508 }
509
510 {
511 let pool = pool.lock();
512 for canceled_user_id in canceled_calls_to_user_ids {
513 for connection_id in pool.user_connection_ids(canceled_user_id) {
514 peer.send(
515 connection_id,
516 proto::CallCanceled {
517 room_id: room_id.to_proto(),
518 },
519 )
520 .trace_err();
521 }
522 }
523 }
524
525 for user_id in contacts_to_update {
526 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
527 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
528 if let Some((busy, contacts)) = busy.zip(contacts) {
529 let pool = pool.lock();
530 let updated_contact = contact_for_user(user_id, busy, &pool);
531 for contact in contacts {
532 if let db::Contact::Accepted {
533 user_id: contact_user_id,
534 ..
535 } = contact
536 {
537 for contact_conn_id in
538 pool.user_connection_ids(contact_user_id)
539 {
540 peer.send(
541 contact_conn_id,
542 proto::UpdateContacts {
543 contacts: vec![updated_contact.clone()],
544 remove_contacts: Default::default(),
545 incoming_requests: Default::default(),
546 remove_incoming_requests: Default::default(),
547 outgoing_requests: Default::default(),
548 remove_outgoing_requests: Default::default(),
549 },
550 )
551 .trace_err();
552 }
553 }
554 }
555 }
556 }
557
558 if let Some(live_kit) = livekit_client.as_ref() {
559 if delete_livekit_room {
560 live_kit.delete_room(livekit_room).await.trace_err();
561 }
562 }
563 }
564 }
565
566 app_state
567 .db
568 .delete_stale_servers(&app_state.config.zed_environment, server_id)
569 .await
570 .trace_err();
571 }
572 .instrument(span),
573 );
574 Ok(())
575 }
576
577 pub fn teardown(&self) {
578 self.peer.teardown();
579 self.connection_pool.lock().reset();
580 let _ = self.teardown.send(true);
581 }
582
583 #[cfg(test)]
584 pub fn reset(&self, id: ServerId) {
585 self.teardown();
586 *self.id.lock() = id;
587 self.peer.reset(id.0 as u32);
588 let _ = self.teardown.send(false);
589 }
590
591 #[cfg(test)]
592 pub fn id(&self) -> ServerId {
593 *self.id.lock()
594 }
595
596 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
597 where
598 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
599 Fut: 'static + Send + Future<Output = Result<()>>,
600 M: EnvelopedMessage,
601 {
602 let prev_handler = self.handlers.insert(
603 TypeId::of::<M>(),
604 Box::new(move |envelope, session| {
605 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
606 let received_at = envelope.received_at;
607 tracing::info!("message received");
608 let start_time = Instant::now();
609 let future = (handler)(*envelope, session);
610 async move {
611 let result = future.await;
612 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
613 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
614 let queue_duration_ms = total_duration_ms - processing_duration_ms;
615 let payload_type = M::NAME;
616
617 match result {
618 Err(error) => {
619 tracing::error!(
620 ?error,
621 total_duration_ms,
622 processing_duration_ms,
623 queue_duration_ms,
624 payload_type,
625 "error handling message"
626 )
627 }
628 Ok(()) => tracing::info!(
629 total_duration_ms,
630 processing_duration_ms,
631 queue_duration_ms,
632 "finished handling message"
633 ),
634 }
635 }
636 .boxed()
637 }),
638 );
639 if prev_handler.is_some() {
640 panic!("registered a handler for the same message twice");
641 }
642 self
643 }
644
645 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
646 where
647 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
648 Fut: 'static + Send + Future<Output = Result<()>>,
649 M: EnvelopedMessage,
650 {
651 self.add_handler(move |envelope, session| handler(envelope.payload, session));
652 self
653 }
654
655 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
656 where
657 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
658 Fut: Send + Future<Output = Result<()>>,
659 M: RequestMessage,
660 {
661 let handler = Arc::new(handler);
662 self.add_handler(move |envelope, session| {
663 let receipt = envelope.receipt();
664 let handler = handler.clone();
665 async move {
666 let peer = session.peer.clone();
667 let responded = Arc::new(AtomicBool::default());
668 let response = Response {
669 peer: peer.clone(),
670 responded: responded.clone(),
671 receipt,
672 };
673 match (handler)(envelope.payload, response, session).await {
674 Ok(()) => {
675 if responded.load(std::sync::atomic::Ordering::SeqCst) {
676 Ok(())
677 } else {
678 Err(anyhow!("handler did not send a response"))?
679 }
680 }
681 Err(error) => {
682 let proto_err = match &error {
683 Error::Internal(err) => err.to_proto(),
684 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
685 };
686 peer.respond_with_error(receipt, proto_err)?;
687 Err(error)
688 }
689 }
690 }
691 })
692 }
693
694 pub fn handle_connection(
695 self: &Arc<Self>,
696 connection: Connection,
697 address: String,
698 principal: Principal,
699 zed_version: ZedVersion,
700 geoip_country_code: Option<String>,
701 system_id: Option<String>,
702 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
703 executor: Executor,
704 ) -> impl Future<Output = ()> {
705 let this = self.clone();
706 let span = info_span!("handle connection", %address,
707 connection_id=field::Empty,
708 user_id=field::Empty,
709 login=field::Empty,
710 impersonator=field::Empty,
711 geoip_country_code=field::Empty
712 );
713 principal.update_span(&span);
714 if let Some(country_code) = geoip_country_code.as_ref() {
715 span.record("geoip_country_code", country_code);
716 }
717
718 let mut teardown = self.teardown.subscribe();
719 async move {
720 if *teardown.borrow() {
721 tracing::error!("server is tearing down");
722 return
723 }
724 let (connection_id, handle_io, mut incoming_rx) = this
725 .peer
726 .add_connection(connection, {
727 let executor = executor.clone();
728 move |duration| executor.sleep(duration)
729 });
730 tracing::Span::current().record("connection_id", format!("{}", connection_id));
731
732 tracing::info!("connection opened");
733
734 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
735 let http_client = match ReqwestClient::user_agent(&user_agent) {
736 Ok(http_client) => Arc::new(http_client),
737 Err(error) => {
738 tracing::error!(?error, "failed to create HTTP client");
739 return;
740 }
741 };
742
743 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
744 supermaven_admin_api_key.to_string(),
745 http_client.clone(),
746 )));
747
748 let session = Session {
749 principal: principal.clone(),
750 connection_id,
751 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
752 peer: this.peer.clone(),
753 connection_pool: this.connection_pool.clone(),
754 app_state: this.app_state.clone(),
755 http_client,
756 geoip_country_code,
757 system_id,
758 _executor: executor.clone(),
759 supermaven_client,
760 };
761
762 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
763 tracing::error!(?error, "failed to send initial client update");
764 return;
765 }
766
767 let handle_io = handle_io.fuse();
768 futures::pin_mut!(handle_io);
769
770 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
771 // This prevents deadlocks when e.g., client A performs a request to client B and
772 // client B performs a request to client A. If both clients stop processing further
773 // messages until their respective request completes, they won't have a chance to
774 // respond to the other client's request and cause a deadlock.
775 //
776 // This arrangement ensures we will attempt to process earlier messages first, but fall
777 // back to processing messages arrived later in the spirit of making progress.
778 let mut foreground_message_handlers = FuturesUnordered::new();
779 let concurrent_handlers = Arc::new(Semaphore::new(256));
780 loop {
781 let next_message = async {
782 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
783 let message = incoming_rx.next().await;
784 (permit, message)
785 }.fuse();
786 futures::pin_mut!(next_message);
787 futures::select_biased! {
788 _ = teardown.changed().fuse() => return,
789 result = handle_io => {
790 if let Err(error) = result {
791 tracing::error!(?error, "error handling I/O");
792 }
793 break;
794 }
795 _ = foreground_message_handlers.next() => {}
796 next_message = next_message => {
797 let (permit, message) = next_message;
798 if let Some(message) = message {
799 let type_name = message.payload_type_name();
800 // note: we copy all the fields from the parent span so we can query them in the logs.
801 // (https://github.com/tokio-rs/tracing/issues/2670).
802 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
803 user_id=field::Empty,
804 login=field::Empty,
805 impersonator=field::Empty,
806 );
807 principal.update_span(&span);
808 let span_enter = span.enter();
809 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
810 let is_background = message.is_background();
811 let handle_message = (handler)(message, session.clone());
812 drop(span_enter);
813
814 let handle_message = async move {
815 handle_message.await;
816 drop(permit);
817 }.instrument(span);
818 if is_background {
819 executor.spawn_detached(handle_message);
820 } else {
821 foreground_message_handlers.push(handle_message);
822 }
823 } else {
824 tracing::error!("no message handler");
825 }
826 } else {
827 tracing::info!("connection closed");
828 break;
829 }
830 }
831 }
832 }
833
834 drop(foreground_message_handlers);
835 tracing::info!("signing out");
836 if let Err(error) = connection_lost(session, teardown, executor).await {
837 tracing::error!(?error, "error signing out");
838 }
839
840 }.instrument(span)
841 }
842
843 async fn send_initial_client_update(
844 &self,
845 connection_id: ConnectionId,
846 principal: &Principal,
847 zed_version: ZedVersion,
848 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
849 session: &Session,
850 ) -> Result<()> {
851 self.peer.send(
852 connection_id,
853 proto::Hello {
854 peer_id: Some(connection_id.into()),
855 },
856 )?;
857 tracing::info!("sent hello message");
858 if let Some(send_connection_id) = send_connection_id.take() {
859 let _ = send_connection_id.send(connection_id);
860 }
861
862 match principal {
863 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
864 if !user.connected_once {
865 self.peer.send(connection_id, proto::ShowContacts {})?;
866 self.app_state
867 .db
868 .set_user_connected_once(user.id, true)
869 .await?;
870 }
871
872 update_user_plan(user.id, session).await?;
873
874 let contacts = self.app_state.db.get_contacts(user.id).await?;
875
876 {
877 let mut pool = self.connection_pool.lock();
878 pool.add_connection(connection_id, user.id, user.admin, zed_version);
879 self.peer.send(
880 connection_id,
881 build_initial_contacts_update(contacts, &pool),
882 )?;
883 }
884
885 if should_auto_subscribe_to_channels(zed_version) {
886 subscribe_user_to_channels(user.id, session).await?;
887 }
888
889 if let Some(incoming_call) =
890 self.app_state.db.incoming_call_for_user(user.id).await?
891 {
892 self.peer.send(connection_id, incoming_call)?;
893 }
894
895 update_user_contacts(user.id, session).await?;
896 }
897 }
898
899 Ok(())
900 }
901
902 pub async fn invite_code_redeemed(
903 self: &Arc<Self>,
904 inviter_id: UserId,
905 invitee_id: UserId,
906 ) -> Result<()> {
907 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
908 if let Some(code) = &user.invite_code {
909 let pool = self.connection_pool.lock();
910 let invitee_contact = contact_for_user(invitee_id, false, &pool);
911 for connection_id in pool.user_connection_ids(inviter_id) {
912 self.peer.send(
913 connection_id,
914 proto::UpdateContacts {
915 contacts: vec![invitee_contact.clone()],
916 ..Default::default()
917 },
918 )?;
919 self.peer.send(
920 connection_id,
921 proto::UpdateInviteInfo {
922 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
923 count: user.invite_count as u32,
924 },
925 )?;
926 }
927 }
928 }
929 Ok(())
930 }
931
932 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
933 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
934 if let Some(invite_code) = &user.invite_code {
935 let pool = self.connection_pool.lock();
936 for connection_id in pool.user_connection_ids(user_id) {
937 self.peer.send(
938 connection_id,
939 proto::UpdateInviteInfo {
940 url: format!(
941 "{}{}",
942 self.app_state.config.invite_link_prefix, invite_code
943 ),
944 count: user.invite_count as u32,
945 },
946 )?;
947 }
948 }
949 }
950 Ok(())
951 }
952
953 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
954 let pool = self.connection_pool.lock();
955 for connection_id in pool.user_connection_ids(user_id) {
956 self.peer
957 .send(connection_id, proto::RefreshLlmToken {})
958 .trace_err();
959 }
960 }
961
962 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
963 ServerSnapshot {
964 connection_pool: ConnectionPoolGuard {
965 guard: self.connection_pool.lock(),
966 _not_send: PhantomData,
967 },
968 peer: &self.peer,
969 }
970 }
971}
972
973impl<'a> Deref for ConnectionPoolGuard<'a> {
974 type Target = ConnectionPool;
975
976 fn deref(&self) -> &Self::Target {
977 &self.guard
978 }
979}
980
981impl<'a> DerefMut for ConnectionPoolGuard<'a> {
982 fn deref_mut(&mut self) -> &mut Self::Target {
983 &mut self.guard
984 }
985}
986
/// Verifies the connection pool's invariants when the guard goes away.
///
/// `drop` runs before the guard's fields are dropped, so the check happens
/// while the wrapped lock guard is still held. The check is compiled only in
/// test builds (`#[cfg(test)]`); otherwise this is a no-op.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `check_invariants` may only exist under `cfg(test)`, so this must stay
        // an attribute-gated statement rather than an `if cfg!(test)` branch.
        #[cfg(test)]
        self.check_invariants();
    }
}
993
994fn broadcast<F>(
995 sender_id: Option<ConnectionId>,
996 receiver_ids: impl IntoIterator<Item = ConnectionId>,
997 mut f: F,
998) where
999 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1000{
1001 for receiver_id in receiver_ids {
1002 if Some(receiver_id) != sender_id {
1003 if let Err(error) = f(receiver_id) {
1004 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1005 }
1006 }
1007 }
1008}
1009
1010pub struct ProtocolVersion(u32);
1011
1012impl Header for ProtocolVersion {
1013 fn name() -> &'static HeaderName {
1014 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1015 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1016 }
1017
1018 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1019 where
1020 Self: Sized,
1021 I: Iterator<Item = &'i axum::http::HeaderValue>,
1022 {
1023 let version = values
1024 .next()
1025 .ok_or_else(axum::headers::Error::invalid)?
1026 .to_str()
1027 .map_err(|_| axum::headers::Error::invalid())?
1028 .parse()
1029 .map_err(|_| axum::headers::Error::invalid())?;
1030 Ok(Self(version))
1031 }
1032
1033 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1034 values.extend([self.0.to_string().parse().unwrap()]);
1035 }
1036}
1037
1038pub struct AppVersionHeader(SemanticVersion);
1039impl Header for AppVersionHeader {
1040 fn name() -> &'static HeaderName {
1041 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1043 }
1044
1045 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046 where
1047 Self: Sized,
1048 I: Iterator<Item = &'i axum::http::HeaderValue>,
1049 {
1050 let version = values
1051 .next()
1052 .ok_or_else(axum::headers::Error::invalid)?
1053 .to_str()
1054 .map_err(|_| axum::headers::Error::invalid())?
1055 .parse()
1056 .map_err(|_| axum::headers::Error::invalid())?;
1057 Ok(Self(version))
1058 }
1059
1060 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061 values.extend([self.0.to_string().parse().unwrap()]);
1062 }
1063}
1064
/// Builds the collab server's HTTP router.
///
/// * `GET /rpc`: the websocket entry point ([`handle_websocket_request`]). It
///   is the only route wrapped by `auth::validate_header`, because axum's
///   `Router::layer` only applies to routes registered before the call. The
///   `AppState` extension is added in that same layer, presumably for the auth
///   middleware's use.
/// * `GET /metrics`: Prometheus metrics ([`handle_metrics`]), registered after
///   the auth layer and therefore not behind it.
///
/// The final `Extension(server)` layer comes after both routes, so both
/// handlers can extract `Arc<Server>`.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Order matters: this layer wraps only the routes declared above it.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1076
1077pub async fn handle_websocket_request(
1078 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1079 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1080 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1081 Extension(server): Extension<Arc<Server>>,
1082 Extension(principal): Extension<Principal>,
1083 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1084 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1085 ws: WebSocketUpgrade,
1086) -> axum::response::Response {
1087 if protocol_version != rpc::PROTOCOL_VERSION {
1088 return (
1089 StatusCode::UPGRADE_REQUIRED,
1090 "client must be upgraded".to_string(),
1091 )
1092 .into_response();
1093 }
1094
1095 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1096 return (
1097 StatusCode::UPGRADE_REQUIRED,
1098 "no version header found".to_string(),
1099 )
1100 .into_response();
1101 };
1102
1103 if !version.can_collaborate() {
1104 return (
1105 StatusCode::UPGRADE_REQUIRED,
1106 "client must be upgraded".to_string(),
1107 )
1108 .into_response();
1109 }
1110
1111 let socket_address = socket_address.to_string();
1112 ws.on_upgrade(move |socket| {
1113 let socket = socket
1114 .map_ok(to_tungstenite_message)
1115 .err_into()
1116 .with(|message| async move { to_axum_message(message) });
1117 let connection = Connection::new(Box::pin(socket));
1118 async move {
1119 server
1120 .handle_connection(
1121 connection,
1122 socket_address,
1123 principal,
1124 version,
1125 country_code_header.map(|header| header.to_string()),
1126 system_id_header.map(|header| header.to_string()),
1127 None,
1128 Executor::Production,
1129 )
1130 .await;
1131 }
1132 })
1133}
1134
1135pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1136 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1137 let connections_metric = CONNECTIONS_METRIC
1138 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1139
1140 let connections = server
1141 .connection_pool
1142 .lock()
1143 .connections()
1144 .filter(|connection| !connection.admin)
1145 .count();
1146 connections_metric.set(connections as _);
1147
1148 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1149 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1150 register_int_gauge!(
1151 "shared_projects",
1152 "number of open projects with one or more guests"
1153 )
1154 .unwrap()
1155 });
1156
1157 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1158 shared_projects_metric.set(shared_projects as _);
1159
1160 let encoder = prometheus::TextEncoder::new();
1161 let metric_families = prometheus::gather();
1162 let encoded_metrics = encoder
1163 .encode_to_string(&metric_families)
1164 .map_err(|err| anyhow!("{}", err))?;
1165 Ok(encoded_metrics)
1166}
1167
/// Cleans up after a client's connection drops.
///
/// The following happens immediately:
/// * the peer is disconnected;
/// * the connection is removed from the pool (errors propagate);
/// * the lost connection is reported to the database (errors are only logged
///   via `trace_err`).
///
/// It then waits for whichever comes first:
/// * `RECONNECT_TIMEOUT` elapses on `executor`. This is presumably the grace
///   period for the client to reconnect. The session then:
///   - leaves its room and channel buffers (errors logged);
///   - declines any pending call if the user has no other live connection;
///   - refreshes the user's contacts.
/// * `teardown` changes, in which case the delayed cleanup is skipped.
///
/// `select_biased!` polls its arms in order, so the timeout arm wins if both
/// are ready in the same poll.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls once no connection of this user remains online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1214
1215/// Acknowledges a ping from a client, used to keep the connection alive.
1216async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1217 response.send(proto::Ack {})?;
1218 Ok(())
1219}
1220
1221/// Creates a new room for calling (outside of channels)
1222async fn create_room(
1223 _request: proto::CreateRoom,
1224 response: Response<proto::CreateRoom>,
1225 session: Session,
1226) -> Result<()> {
1227 let livekit_room = nanoid::nanoid!(30);
1228
1229 let live_kit_connection_info = util::maybe!(async {
1230 let live_kit = session.app_state.livekit_client.as_ref();
1231 let live_kit = live_kit?;
1232 let user_id = session.user_id().to_string();
1233
1234 let token = live_kit
1235 .room_token(&livekit_room, &user_id.to_string())
1236 .trace_err()?;
1237
1238 Some(proto::LiveKitConnectionInfo {
1239 server_url: live_kit.url().into(),
1240 token,
1241 can_publish: true,
1242 })
1243 })
1244 .await;
1245
1246 let room = session
1247 .db()
1248 .await
1249 .create_room(session.user_id(), session.connection_id, &livekit_room)
1250 .await?;
1251
1252 response.send(proto::CreateRoomResponse {
1253 room: Some(room.clone()),
1254 live_kit_connection_info,
1255 })?;
1256
1257 update_user_contacts(session.user_id(), &session).await?;
1258 Ok(())
1259}
1260
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, this delegates to `join_channel_internal`.
/// Otherwise it:
/// 1. records the join in the database and calls `room_updated`;
/// 2. sends `CallCanceled` for this room to every connection of the joining
///    user, presumably so the user's other clients dismiss the incoming call
///    (send failures are only logged);
/// 3. mints a LiveKit token when LiveKit is configured;
/// 4. replies with the room;
/// 5. refreshes the user's contacts.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // `room` appears to be a guard (`into_inner` unwraps it). This block keeps
    // it alive only long enough to broadcast the room.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The pool guard is a temporary of the `for` expression, so it is held for
    // the whole loop and released before the awaits further down.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1327
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-associates the user with the room under the session's new
/// connection. It returns the room plus `reshared_projects` and
/// `rejoined_projects`; judging by how they are handled below, these are
/// projects this user hosts and projects they were a guest in.
///
/// Then, in order:
/// 1. Reply with the room and both project lists, then call `room_updated`.
/// 2. For each reshared project, tell every collaborator that this user's
///    peer id changed from `old_connection_id` to the current connection.
///    Then forward an `UpdateProject` with the project's worktrees to the
///    collaborators (excluding this connection).
/// 3. Replay rejoined-project state to this connection via
///    `notify_rejoined_projects`.
/// 4. After the database result is released, notify channel members if the
///    room belongs to a channel, and refresh the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Inner scope: `rejoined_room` is consumed by `into_inner` before the
    // connection pool is locked for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1419
1420fn notify_rejoined_projects(
1421 rejoined_projects: &mut Vec<RejoinedProject>,
1422 session: &Session,
1423) -> Result<()> {
1424 for project in rejoined_projects.iter() {
1425 for collaborator in &project.collaborators {
1426 session
1427 .peer
1428 .send(
1429 collaborator.connection_id,
1430 proto::UpdateProjectCollaborator {
1431 project_id: project.id.to_proto(),
1432 old_peer_id: Some(project.old_connection_id.into()),
1433 new_peer_id: Some(session.connection_id.into()),
1434 },
1435 )
1436 .trace_err();
1437 }
1438 }
1439
1440 for project in rejoined_projects {
1441 for worktree in mem::take(&mut project.worktrees) {
1442 // Stream this worktree's entries.
1443 let message = proto::UpdateWorktree {
1444 project_id: project.id.to_proto(),
1445 worktree_id: worktree.id,
1446 abs_path: worktree.abs_path.clone(),
1447 root_name: worktree.root_name,
1448 updated_entries: worktree.updated_entries,
1449 removed_entries: worktree.removed_entries,
1450 scan_id: worktree.scan_id,
1451 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1452 updated_repositories: worktree.updated_repositories,
1453 removed_repositories: worktree.removed_repositories,
1454 };
1455 for update in proto::split_worktree_update(message) {
1456 session.peer.send(session.connection_id, update.clone())?;
1457 }
1458
1459 // Stream this worktree's diagnostics.
1460 for summary in worktree.diagnostic_summaries {
1461 session.peer.send(
1462 session.connection_id,
1463 proto::UpdateDiagnosticSummary {
1464 project_id: project.id.to_proto(),
1465 worktree_id: worktree.id,
1466 summary: Some(summary),
1467 },
1468 )?;
1469 }
1470
1471 for settings_file in worktree.settings_files {
1472 session.peer.send(
1473 session.connection_id,
1474 proto::UpdateWorktreeSettings {
1475 project_id: project.id.to_proto(),
1476 worktree_id: worktree.id,
1477 path: settings_file.path,
1478 content: Some(settings_file.content),
1479 kind: Some(settings_file.kind.to_proto().into()),
1480 },
1481 )?;
1482 }
1483 }
1484
1485 for language_server in &project.language_servers {
1486 session.peer.send(
1487 session.connection_id,
1488 proto::UpdateLanguageServer {
1489 project_id: project.id.to_proto(),
1490 language_server_id: language_server.id,
1491 variant: Some(
1492 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1493 proto::LspDiskBasedDiagnosticsUpdated {},
1494 ),
1495 ),
1496 },
1497 )?;
1498 }
1499 }
1500 Ok(())
1501}
1502
1503/// leave room disconnects from the room.
1504async fn leave_room(
1505 _: proto::LeaveRoom,
1506 response: Response<proto::LeaveRoom>,
1507 session: Session,
1508) -> Result<()> {
1509 leave_room_for_session(&session, session.connection_id).await?;
1510 response.send(proto::Ack {})?;
1511 Ok(())
1512}
1513
1514/// Updates the permissions of someone else in the room.
1515async fn set_room_participant_role(
1516 request: proto::SetRoomParticipantRole,
1517 response: Response<proto::SetRoomParticipantRole>,
1518 session: Session,
1519) -> Result<()> {
1520 let user_id = UserId::from_proto(request.user_id);
1521 let role = ChannelRole::from(request.role());
1522
1523 let (livekit_room, can_publish) = {
1524 let room = session
1525 .db()
1526 .await
1527 .set_room_participant_role(
1528 session.user_id(),
1529 RoomId::from_proto(request.room_id),
1530 user_id,
1531 role,
1532 )
1533 .await?;
1534
1535 let livekit_room = room.livekit_room.clone();
1536 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1537 room_updated(&room, &session.peer);
1538 (livekit_room, can_publish)
1539 };
1540
1541 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1542 live_kit
1543 .update_participant(
1544 livekit_room.clone(),
1545 request.user_id.to_string(),
1546 livekit_api::proto::ParticipantPermission {
1547 can_subscribe: true,
1548 can_publish,
1549 can_publish_data: can_publish,
1550 hidden: false,
1551 recorder: false,
1552 },
1553 )
1554 .await
1555 .trace_err();
1556 }
1557
1558 response.send(proto::Ack {})?;
1559 Ok(())
1560}
1561
/// Call someone else into the current room
///
/// The caller must have the callee as a contact. The database records the
/// call (with an optional `initial_project_id`), the room is broadcast, and
/// the callee's contacts are refreshed.
///
/// The incoming call is then sent as a request to every connection of the
/// callee, all at once:
/// * the first connection that answers successfully makes this handler ack
///   and return `Ok`; the other in-flight futures are dropped unawaited;
/// * if none succeed (or the callee has no connections), the database is
///   told the call failed, the room and the callee's contacts are updated
///   again, and `"failed to ring user"` is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // `mem::take` moves the incoming-call message out of the database result so
    // it outlives this block's borrow of that result.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The pool guard is a temporary of this statement: the ring requests are
    // created while it is held, but awaited only after it is released.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1630
/// Cancel an outgoing call.
///
/// The database cancels the call from the session's connection to
/// `called_user_id` in the given room, and the updated room is broadcast.
/// Then:
/// 1. every connection of the called user is sent `CallCanceled` (send
///    failures are only logged);
/// 2. the request is acknowledged;
/// 3. the called user's contacts are refreshed.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is released before the pool is locked.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1668
/// Decline an incoming call.
///
/// This handles a one-way message, so there is no response to send. The
/// database records the decline for `message.room_id`; if it returns no room,
/// the handler fails with "failed to decline call". Otherwise:
/// 1. the updated room is broadcast;
/// 2. all of the declining user's own connections are sent `CallCanceled`,
///    presumably so their other clients stop showing the call (failures are
///    only logged);
/// 3. the user's contacts are refreshed.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is released before the pool is locked.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1700
1701/// Updates other participants in the room with your current location.
1702async fn update_participant_location(
1703 request: proto::UpdateParticipantLocation,
1704 response: Response<proto::UpdateParticipantLocation>,
1705 session: Session,
1706) -> Result<()> {
1707 let room_id = RoomId::from_proto(request.room_id);
1708 let location = request
1709 .location
1710 .ok_or_else(|| anyhow!("invalid location"))?;
1711
1712 let db = session.db().await;
1713 let room = db
1714 .update_room_participant_location(room_id, session.connection_id, location)
1715 .await?;
1716
1717 room_updated(&room, &session.peer);
1718 response.send(proto::Ack {})?;
1719 Ok(())
1720}
1721
/// Share a project into the room.
///
/// The database registers a project for the session's connection in the
/// given room, using the request's worktrees and `is_ssh_project` flag. The
/// new project id is returned to the sender before `room_updated` is called
/// for the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The database result is borrowed through `&*` and stays alive for the rest
    // of the handler.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1745
1746/// Unshare a project from the room.
1747async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1748 let project_id = ProjectId::from_proto(message.project_id);
1749 unshare_project_internal(project_id, session.connection_id, &session).await
1750}
1751
/// Unshares `project_id` on behalf of `connection_id`.
///
/// While the database result is held, every guest connection it reports
/// (other than `connection_id`) is sent `UnshareProject`. If the result
/// includes a room, `room_updated` is called for it. If the database result
/// says the project should be deleted, it is deleted afterwards in a separate
/// database call, once the first result has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so `room_guard` can be dropped at the end of this block.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1789
/// Join someone else's shared project.
///
/// The database adds the session's connection and user to the project and
/// returns the project together with the replica id assigned to the guest.
/// The `db` handle is dropped explicitly before the synchronous state
/// streaming in [`join_project_internal`], while the join result itself stays
/// borrowed for that call.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1808
/// Something that can deliver a `JoinProjectResponse` to the joining client.
///
/// [`join_project_internal`] takes `impl JoinProjectInternalResponse`, so it is
/// not tied to one concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and sends `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies through the ordinary RPC response of a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The fully qualified path makes explicit that this forwards to the
        // inherent `Response::send`, not back into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1817
1818fn join_project_internal(
1819 response: impl JoinProjectInternalResponse,
1820 session: Session,
1821 project: &mut Project,
1822 replica_id: &ReplicaId,
1823) -> Result<()> {
1824 let collaborators = project
1825 .collaborators
1826 .iter()
1827 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1828 .map(|collaborator| collaborator.to_proto())
1829 .collect::<Vec<_>>();
1830 let project_id = project.id;
1831 let guest_user_id = session.user_id();
1832
1833 let worktrees = project
1834 .worktrees
1835 .iter()
1836 .map(|(id, worktree)| proto::WorktreeMetadata {
1837 id: *id,
1838 root_name: worktree.root_name.clone(),
1839 visible: worktree.visible,
1840 abs_path: worktree.abs_path.clone(),
1841 })
1842 .collect::<Vec<_>>();
1843
1844 let add_project_collaborator = proto::AddProjectCollaborator {
1845 project_id: project_id.to_proto(),
1846 collaborator: Some(proto::Collaborator {
1847 peer_id: Some(session.connection_id.into()),
1848 replica_id: replica_id.0 as u32,
1849 user_id: guest_user_id.to_proto(),
1850 is_host: false,
1851 }),
1852 };
1853
1854 for collaborator in &collaborators {
1855 session
1856 .peer
1857 .send(
1858 collaborator.peer_id.unwrap().into(),
1859 add_project_collaborator.clone(),
1860 )
1861 .trace_err();
1862 }
1863
1864 // First, we send the metadata associated with each worktree.
1865 response.send(proto::JoinProjectResponse {
1866 project_id: project.id.0 as u64,
1867 worktrees: worktrees.clone(),
1868 replica_id: replica_id.0 as u32,
1869 collaborators: collaborators.clone(),
1870 language_servers: project.language_servers.clone(),
1871 role: project.role.into(),
1872 })?;
1873
1874 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1875 // Stream this worktree's entries.
1876 let message = proto::UpdateWorktree {
1877 project_id: project_id.to_proto(),
1878 worktree_id,
1879 abs_path: worktree.abs_path.clone(),
1880 root_name: worktree.root_name,
1881 updated_entries: worktree.entries,
1882 removed_entries: Default::default(),
1883 scan_id: worktree.scan_id,
1884 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1885 updated_repositories: worktree.repository_entries.into_values().collect(),
1886 removed_repositories: Default::default(),
1887 };
1888 for update in proto::split_worktree_update(message) {
1889 session.peer.send(session.connection_id, update.clone())?;
1890 }
1891
1892 // Stream this worktree's diagnostics.
1893 for summary in worktree.diagnostic_summaries {
1894 session.peer.send(
1895 session.connection_id,
1896 proto::UpdateDiagnosticSummary {
1897 project_id: project_id.to_proto(),
1898 worktree_id: worktree.id,
1899 summary: Some(summary),
1900 },
1901 )?;
1902 }
1903
1904 for settings_file in worktree.settings_files {
1905 session.peer.send(
1906 session.connection_id,
1907 proto::UpdateWorktreeSettings {
1908 project_id: project_id.to_proto(),
1909 worktree_id: worktree.id,
1910 path: settings_file.path,
1911 content: Some(settings_file.content),
1912 kind: Some(settings_file.kind.to_proto() as i32),
1913 },
1914 )?;
1915 }
1916 }
1917
1918 for language_server in &project.language_servers {
1919 session.peer.send(
1920 session.connection_id,
1921 proto::UpdateLanguageServer {
1922 project_id: project_id.to_proto(),
1923 language_server_id: language_server.id,
1924 variant: Some(
1925 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1926 proto::LspDiskBasedDiagnosticsUpdated {},
1927 ),
1928 ),
1929 },
1930 )?;
1931 }
1932
1933 Ok(())
1934}
1935
/// Leave someone else's shared project.
///
/// The database removes the sender's connection from the project and returns
/// the affected room (if any) and the project. `project_left` is then called
/// for the project, and `room_updated` for the room when present. The `db`
/// handle is held for the whole handler.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1955
/// Updates other participants with changes to the project.
///
/// The database applies the request's worktree list for the sender and
/// returns the affected room (if any) and the guest connections. While that
/// result is held:
/// 1. the original request is forwarded to every guest except the sender;
/// 2. `room_updated` is called when a room is present.
///
/// The request is acknowledged last.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
1984
1985/// Updates other participants with changes to the worktree
1986async fn update_worktree(
1987 request: proto::UpdateWorktree,
1988 response: Response<proto::UpdateWorktree>,
1989 session: Session,
1990) -> Result<()> {
1991 let guest_connection_ids = session
1992 .db()
1993 .await
1994 .update_worktree(&request, session.connection_id)
1995 .await?;
1996
1997 broadcast(
1998 Some(session.connection_id),
1999 guest_connection_ids.iter().copied(),
2000 |connection_id| {
2001 session
2002 .peer
2003 .forward_send(session.connection_id, connection_id, request.clone())
2004 },
2005 );
2006 response.send(proto::Ack {})?;
2007 Ok(())
2008}
2009
2010/// Updates other participants with changes to the diagnostics
2011async fn update_diagnostic_summary(
2012 message: proto::UpdateDiagnosticSummary,
2013 session: Session,
2014) -> Result<()> {
2015 let guest_connection_ids = session
2016 .db()
2017 .await
2018 .update_diagnostic_summary(&message, session.connection_id)
2019 .await?;
2020
2021 broadcast(
2022 Some(session.connection_id),
2023 guest_connection_ids.iter().copied(),
2024 |connection_id| {
2025 session
2026 .peer
2027 .forward_send(session.connection_id, connection_id, message.clone())
2028 },
2029 );
2030
2031 Ok(())
2032}
2033
2034/// Updates other participants with changes to the worktree settings
2035async fn update_worktree_settings(
2036 message: proto::UpdateWorktreeSettings,
2037 session: Session,
2038) -> Result<()> {
2039 let guest_connection_ids = session
2040 .db()
2041 .await
2042 .update_worktree_settings(&message, session.connection_id)
2043 .await?;
2044
2045 broadcast(
2046 Some(session.connection_id),
2047 guest_connection_ids.iter().copied(),
2048 |connection_id| {
2049 session
2050 .peer
2051 .forward_send(session.connection_id, connection_id, message.clone())
2052 },
2053 );
2054
2055 Ok(())
2056}
2057
2058/// Notify other participants that a language server has started.
2059async fn start_language_server(
2060 request: proto::StartLanguageServer,
2061 session: Session,
2062) -> Result<()> {
2063 let guest_connection_ids = session
2064 .db()
2065 .await
2066 .start_language_server(&request, session.connection_id)
2067 .await?;
2068
2069 broadcast(
2070 Some(session.connection_id),
2071 guest_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, request.clone())
2076 },
2077 );
2078 Ok(())
2079}
2080
2081/// Notify other participants that a language server has changed.
2082async fn update_language_server(
2083 request: proto::UpdateLanguageServer,
2084 session: Session,
2085) -> Result<()> {
2086 let project_id = ProjectId::from_proto(request.project_id);
2087 let project_connection_ids = session
2088 .db()
2089 .await
2090 .project_connection_ids(project_id, session.connection_id, true)
2091 .await?;
2092 broadcast(
2093 Some(session.connection_id),
2094 project_connection_ids.iter().copied(),
2095 |connection_id| {
2096 session
2097 .peer
2098 .forward_send(session.connection_id, connection_id, request.clone())
2099 },
2100 );
2101 Ok(())
2102}
2103
2104/// forward a project request to the host. These requests should be read only
2105/// as guests are allowed to send them.
2106async fn forward_read_only_project_request<T>(
2107 request: T,
2108 response: Response<T>,
2109 session: Session,
2110) -> Result<()>
2111where
2112 T: EntityMessage + RequestMessage,
2113{
2114 let project_id = ProjectId::from_proto(request.remote_entity_id());
2115 let host_connection_id = session
2116 .db()
2117 .await
2118 .host_for_read_only_project_request(project_id, session.connection_id)
2119 .await?;
2120 let payload = session
2121 .peer
2122 .forward_request(session.connection_id, host_connection_id, request)
2123 .await?;
2124 response.send(payload)?;
2125 Ok(())
2126}
2127
2128async fn forward_find_search_candidates_request(
2129 request: proto::FindSearchCandidates,
2130 response: Response<proto::FindSearchCandidates>,
2131 session: Session,
2132) -> Result<()> {
2133 let project_id = ProjectId::from_proto(request.remote_entity_id());
2134 let host_connection_id = session
2135 .db()
2136 .await
2137 .host_for_read_only_project_request(project_id, session.connection_id)
2138 .await?;
2139 let payload = session
2140 .peer
2141 .forward_request(session.connection_id, host_connection_id, request)
2142 .await?;
2143 response.send(payload)?;
2144 Ok(())
2145}
2146
2147/// forward a project request to the host. These requests are disallowed
2148/// for guests.
2149async fn forward_mutating_project_request<T>(
2150 request: T,
2151 response: Response<T>,
2152 session: Session,
2153) -> Result<()>
2154where
2155 T: EntityMessage + RequestMessage,
2156{
2157 let project_id = ProjectId::from_proto(request.remote_entity_id());
2158
2159 let host_connection_id = session
2160 .db()
2161 .await
2162 .host_for_mutating_project_request(project_id, session.connection_id)
2163 .await?;
2164 let payload = session
2165 .peer
2166 .forward_request(session.connection_id, host_connection_id, request)
2167 .await?;
2168 response.send(payload)?;
2169 Ok(())
2170}
2171
2172/// Notify other participants that a new buffer has been created
2173async fn create_buffer_for_peer(
2174 request: proto::CreateBufferForPeer,
2175 session: Session,
2176) -> Result<()> {
2177 session
2178 .db()
2179 .await
2180 .check_user_is_project_host(
2181 ProjectId::from_proto(request.project_id),
2182 session.connection_id,
2183 )
2184 .await?;
2185 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2186 session
2187 .peer
2188 .forward_send(session.connection_id, peer_id.into(), request)?;
2189 Ok(())
2190}
2191
2192/// Notify other participants that a buffer has been updated. This is
2193/// allowed for guests as long as the update is limited to selections.
2194async fn update_buffer(
2195 request: proto::UpdateBuffer,
2196 response: Response<proto::UpdateBuffer>,
2197 session: Session,
2198) -> Result<()> {
2199 let project_id = ProjectId::from_proto(request.project_id);
2200 let mut capability = Capability::ReadOnly;
2201
2202 for op in request.operations.iter() {
2203 match op.variant {
2204 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2205 Some(_) => capability = Capability::ReadWrite,
2206 }
2207 }
2208
2209 let host = {
2210 let guard = session
2211 .db()
2212 .await
2213 .connections_for_buffer_update(project_id, session.connection_id, capability)
2214 .await?;
2215
2216 let (host, guests) = &*guard;
2217
2218 broadcast(
2219 Some(session.connection_id),
2220 guests.clone(),
2221 |connection_id| {
2222 session
2223 .peer
2224 .forward_send(session.connection_id, connection_id, request.clone())
2225 },
2226 );
2227
2228 *host
2229 };
2230
2231 if host != session.connection_id {
2232 session
2233 .peer
2234 .forward_request(session.connection_id, host, request.clone())
2235 .await?;
2236 }
2237
2238 response.send(proto::Ack {})?;
2239 Ok(())
2240}
2241
2242async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2243 let project_id = ProjectId::from_proto(message.project_id);
2244
2245 let operation = message.operation.as_ref().context("invalid operation")?;
2246 let capability = match operation.variant.as_ref() {
2247 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2248 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2249 match buffer_op.variant {
2250 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2251 Capability::ReadOnly
2252 }
2253 _ => Capability::ReadWrite,
2254 }
2255 } else {
2256 Capability::ReadWrite
2257 }
2258 }
2259 Some(_) => Capability::ReadWrite,
2260 None => Capability::ReadOnly,
2261 };
2262
2263 let guard = session
2264 .db()
2265 .await
2266 .connections_for_buffer_update(project_id, session.connection_id, capability)
2267 .await?;
2268
2269 let (host, guests) = &*guard;
2270
2271 broadcast(
2272 Some(session.connection_id),
2273 guests.iter().chain([host]).copied(),
2274 |connection_id| {
2275 session
2276 .peer
2277 .forward_send(session.connection_id, connection_id, message.clone())
2278 },
2279 );
2280
2281 Ok(())
2282}
2283
2284/// Notify other participants that a project has been updated.
2285async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2286 request: T,
2287 session: Session,
2288) -> Result<()> {
2289 let project_id = ProjectId::from_proto(request.remote_entity_id());
2290 let project_connection_ids = session
2291 .db()
2292 .await
2293 .project_connection_ids(project_id, session.connection_id, false)
2294 .await?;
2295
2296 broadcast(
2297 Some(session.connection_id),
2298 project_connection_ids.iter().copied(),
2299 |connection_id| {
2300 session
2301 .peer
2302 .forward_send(session.connection_id, connection_id, request.clone())
2303 },
2304 );
2305 Ok(())
2306}
2307
2308/// Start following another user in a call.
2309async fn follow(
2310 request: proto::Follow,
2311 response: Response<proto::Follow>,
2312 session: Session,
2313) -> Result<()> {
2314 let room_id = RoomId::from_proto(request.room_id);
2315 let project_id = request.project_id.map(ProjectId::from_proto);
2316 let leader_id = request
2317 .leader_id
2318 .ok_or_else(|| anyhow!("invalid leader id"))?
2319 .into();
2320 let follower_id = session.connection_id;
2321
2322 session
2323 .db()
2324 .await
2325 .check_room_participants(room_id, leader_id, session.connection_id)
2326 .await?;
2327
2328 let response_payload = session
2329 .peer
2330 .forward_request(session.connection_id, leader_id, request)
2331 .await?;
2332 response.send(response_payload)?;
2333
2334 if let Some(project_id) = project_id {
2335 let room = session
2336 .db()
2337 .await
2338 .follow(room_id, project_id, leader_id, follower_id)
2339 .await?;
2340 room_updated(&room, &session.peer);
2341 }
2342
2343 Ok(())
2344}
2345
2346/// Stop following another user in a call.
2347async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2348 let room_id = RoomId::from_proto(request.room_id);
2349 let project_id = request.project_id.map(ProjectId::from_proto);
2350 let leader_id = request
2351 .leader_id
2352 .ok_or_else(|| anyhow!("invalid leader id"))?
2353 .into();
2354 let follower_id = session.connection_id;
2355
2356 session
2357 .db()
2358 .await
2359 .check_room_participants(room_id, leader_id, session.connection_id)
2360 .await?;
2361
2362 session
2363 .peer
2364 .forward_send(session.connection_id, leader_id, request)?;
2365
2366 if let Some(project_id) = project_id {
2367 let room = session
2368 .db()
2369 .await
2370 .unfollow(room_id, project_id, leader_id, follower_id)
2371 .await?;
2372 room_updated(&room, &session.peer);
2373 }
2374
2375 Ok(())
2376}
2377
2378/// Notify everyone following you of your current location.
2379async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2380 let room_id = RoomId::from_proto(request.room_id);
2381 let database = session.db.lock().await;
2382
2383 let connection_ids = if let Some(project_id) = request.project_id {
2384 let project_id = ProjectId::from_proto(project_id);
2385 database
2386 .project_connection_ids(project_id, session.connection_id, true)
2387 .await?
2388 } else {
2389 database
2390 .room_connection_ids(room_id, session.connection_id)
2391 .await?
2392 };
2393
2394 // For now, don't send view update messages back to that view's current leader.
2395 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2396 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2397 _ => None,
2398 });
2399
2400 for connection_id in connection_ids.iter().cloned() {
2401 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2402 session
2403 .peer
2404 .forward_send(session.connection_id, connection_id, request.clone())?;
2405 }
2406 }
2407 Ok(())
2408}
2409
2410/// Get public data about users.
2411async fn get_users(
2412 request: proto::GetUsers,
2413 response: Response<proto::GetUsers>,
2414 session: Session,
2415) -> Result<()> {
2416 let user_ids = request
2417 .user_ids
2418 .into_iter()
2419 .map(UserId::from_proto)
2420 .collect();
2421 let users = session
2422 .db()
2423 .await
2424 .get_users_by_ids(user_ids)
2425 .await?
2426 .into_iter()
2427 .map(|user| proto::User {
2428 id: user.id.to_proto(),
2429 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2430 github_login: user.github_login,
2431 email: user.email_address,
2432 name: user.name,
2433 })
2434 .collect();
2435 response.send(proto::UsersResponse { users })?;
2436 Ok(())
2437}
2438
2439/// Search for users (to invite) buy Github login
2440async fn fuzzy_search_users(
2441 request: proto::FuzzySearchUsers,
2442 response: Response<proto::FuzzySearchUsers>,
2443 session: Session,
2444) -> Result<()> {
2445 let query = request.query;
2446 let users = match query.len() {
2447 0 => vec![],
2448 1 | 2 => session
2449 .db()
2450 .await
2451 .get_user_by_github_login(&query)
2452 .await?
2453 .into_iter()
2454 .collect(),
2455 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2456 };
2457 let users = users
2458 .into_iter()
2459 .filter(|user| user.id != session.user_id())
2460 .map(|user| proto::User {
2461 id: user.id.to_proto(),
2462 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2463 github_login: user.github_login,
2464 name: user.name,
2465 email: user.email_address,
2466 })
2467 .collect();
2468 response.send(proto::UsersResponse { users })?;
2469 Ok(())
2470}
2471
2472/// Send a contact request to another user.
2473async fn request_contact(
2474 request: proto::RequestContact,
2475 response: Response<proto::RequestContact>,
2476 session: Session,
2477) -> Result<()> {
2478 let requester_id = session.user_id();
2479 let responder_id = UserId::from_proto(request.responder_id);
2480 if requester_id == responder_id {
2481 return Err(anyhow!("cannot add yourself as a contact"))?;
2482 }
2483
2484 let notifications = session
2485 .db()
2486 .await
2487 .send_contact_request(requester_id, responder_id)
2488 .await?;
2489
2490 // Update outgoing contact requests of requester
2491 let mut update = proto::UpdateContacts::default();
2492 update.outgoing_requests.push(responder_id.to_proto());
2493 for connection_id in session
2494 .connection_pool()
2495 .await
2496 .user_connection_ids(requester_id)
2497 {
2498 session.peer.send(connection_id, update.clone())?;
2499 }
2500
2501 // Update incoming contact requests of responder
2502 let mut update = proto::UpdateContacts::default();
2503 update
2504 .incoming_requests
2505 .push(proto::IncomingContactRequest {
2506 requester_id: requester_id.to_proto(),
2507 });
2508 let connection_pool = session.connection_pool().await;
2509 for connection_id in connection_pool.user_connection_ids(responder_id) {
2510 session.peer.send(connection_id, update.clone())?;
2511 }
2512
2513 send_notifications(&connection_pool, &session.peer, notifications);
2514
2515 response.send(proto::Ack {})?;
2516 Ok(())
2517}
2518
2519/// Accept or decline a contact request
2520async fn respond_to_contact_request(
2521 request: proto::RespondToContactRequest,
2522 response: Response<proto::RespondToContactRequest>,
2523 session: Session,
2524) -> Result<()> {
2525 let responder_id = session.user_id();
2526 let requester_id = UserId::from_proto(request.requester_id);
2527 let db = session.db().await;
2528 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2529 db.dismiss_contact_notification(responder_id, requester_id)
2530 .await?;
2531 } else {
2532 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2533
2534 let notifications = db
2535 .respond_to_contact_request(responder_id, requester_id, accept)
2536 .await?;
2537 let requester_busy = db.is_user_busy(requester_id).await?;
2538 let responder_busy = db.is_user_busy(responder_id).await?;
2539
2540 let pool = session.connection_pool().await;
2541 // Update responder with new contact
2542 let mut update = proto::UpdateContacts::default();
2543 if accept {
2544 update
2545 .contacts
2546 .push(contact_for_user(requester_id, requester_busy, &pool));
2547 }
2548 update
2549 .remove_incoming_requests
2550 .push(requester_id.to_proto());
2551 for connection_id in pool.user_connection_ids(responder_id) {
2552 session.peer.send(connection_id, update.clone())?;
2553 }
2554
2555 // Update requester with new contact
2556 let mut update = proto::UpdateContacts::default();
2557 if accept {
2558 update
2559 .contacts
2560 .push(contact_for_user(responder_id, responder_busy, &pool));
2561 }
2562 update
2563 .remove_outgoing_requests
2564 .push(responder_id.to_proto());
2565
2566 for connection_id in pool.user_connection_ids(requester_id) {
2567 session.peer.send(connection_id, update.clone())?;
2568 }
2569
2570 send_notifications(&pool, &session.peer, notifications);
2571 }
2572
2573 response.send(proto::Ack {})?;
2574 Ok(())
2575}
2576
2577/// Remove a contact.
2578async fn remove_contact(
2579 request: proto::RemoveContact,
2580 response: Response<proto::RemoveContact>,
2581 session: Session,
2582) -> Result<()> {
2583 let requester_id = session.user_id();
2584 let responder_id = UserId::from_proto(request.user_id);
2585 let db = session.db().await;
2586 let (contact_accepted, deleted_notification_id) =
2587 db.remove_contact(requester_id, responder_id).await?;
2588
2589 let pool = session.connection_pool().await;
2590 // Update outgoing contact requests of requester
2591 let mut update = proto::UpdateContacts::default();
2592 if contact_accepted {
2593 update.remove_contacts.push(responder_id.to_proto());
2594 } else {
2595 update
2596 .remove_outgoing_requests
2597 .push(responder_id.to_proto());
2598 }
2599 for connection_id in pool.user_connection_ids(requester_id) {
2600 session.peer.send(connection_id, update.clone())?;
2601 }
2602
2603 // Update incoming contact requests of responder
2604 let mut update = proto::UpdateContacts::default();
2605 if contact_accepted {
2606 update.remove_contacts.push(requester_id.to_proto());
2607 } else {
2608 update
2609 .remove_incoming_requests
2610 .push(requester_id.to_proto());
2611 }
2612 for connection_id in pool.user_connection_ids(responder_id) {
2613 session.peer.send(connection_id, update.clone())?;
2614 if let Some(notification_id) = deleted_notification_id {
2615 session.peer.send(
2616 connection_id,
2617 proto::DeleteNotification {
2618 notification_id: notification_id.to_proto(),
2619 },
2620 )?;
2621 }
2622 }
2623
2624 response.send(proto::Ack {})?;
2625 Ok(())
2626}
2627
2628fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2629 version.0.minor() < 139
2630}
2631
2632async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2633 let plan = session.current_plan(&session.db().await).await?;
2634
2635 session
2636 .peer
2637 .send(
2638 session.connection_id,
2639 proto::UpdateUserPlan { plan: plan.into() },
2640 )
2641 .trace_err();
2642
2643 Ok(())
2644}
2645
2646async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2647 subscribe_user_to_channels(session.user_id(), &session).await?;
2648 Ok(())
2649}
2650
2651async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2652 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2653 let mut pool = session.connection_pool().await;
2654 for membership in &channels_for_user.channel_memberships {
2655 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2656 }
2657 session.peer.send(
2658 session.connection_id,
2659 build_update_user_channels(&channels_for_user),
2660 )?;
2661 session.peer.send(
2662 session.connection_id,
2663 build_channels_update(channels_for_user),
2664 )?;
2665 Ok(())
2666}
2667
2668/// Creates a new channel.
2669async fn create_channel(
2670 request: proto::CreateChannel,
2671 response: Response<proto::CreateChannel>,
2672 session: Session,
2673) -> Result<()> {
2674 let db = session.db().await;
2675
2676 let parent_id = request.parent_id.map(ChannelId::from_proto);
2677 let (channel, membership) = db
2678 .create_channel(&request.name, parent_id, session.user_id())
2679 .await?;
2680
2681 let root_id = channel.root_id();
2682 let channel = Channel::from_model(channel);
2683
2684 response.send(proto::CreateChannelResponse {
2685 channel: Some(channel.to_proto()),
2686 parent_id: request.parent_id,
2687 })?;
2688
2689 let mut connection_pool = session.connection_pool().await;
2690 if let Some(membership) = membership {
2691 connection_pool.subscribe_to_channel(
2692 membership.user_id,
2693 membership.channel_id,
2694 membership.role,
2695 );
2696 let update = proto::UpdateUserChannels {
2697 channel_memberships: vec![proto::ChannelMembership {
2698 channel_id: membership.channel_id.to_proto(),
2699 role: membership.role.into(),
2700 }],
2701 ..Default::default()
2702 };
2703 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2704 session.peer.send(connection_id, update.clone())?;
2705 }
2706 }
2707
2708 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2709 if !role.can_see_channel(channel.visibility) {
2710 continue;
2711 }
2712
2713 let update = proto::UpdateChannels {
2714 channels: vec![channel.to_proto()],
2715 ..Default::default()
2716 };
2717 session.peer.send(connection_id, update.clone())?;
2718 }
2719
2720 Ok(())
2721}
2722
2723/// Delete a channel
2724async fn delete_channel(
2725 request: proto::DeleteChannel,
2726 response: Response<proto::DeleteChannel>,
2727 session: Session,
2728) -> Result<()> {
2729 let db = session.db().await;
2730
2731 let channel_id = request.channel_id;
2732 let (root_channel, removed_channels) = db
2733 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2734 .await?;
2735 response.send(proto::Ack {})?;
2736
2737 // Notify members of removed channels
2738 let mut update = proto::UpdateChannels::default();
2739 update
2740 .delete_channels
2741 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2742
2743 let connection_pool = session.connection_pool().await;
2744 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2745 session.peer.send(connection_id, update.clone())?;
2746 }
2747
2748 Ok(())
2749}
2750
2751/// Invite someone to join a channel.
2752async fn invite_channel_member(
2753 request: proto::InviteChannelMember,
2754 response: Response<proto::InviteChannelMember>,
2755 session: Session,
2756) -> Result<()> {
2757 let db = session.db().await;
2758 let channel_id = ChannelId::from_proto(request.channel_id);
2759 let invitee_id = UserId::from_proto(request.user_id);
2760 let InviteMemberResult {
2761 channel,
2762 notifications,
2763 } = db
2764 .invite_channel_member(
2765 channel_id,
2766 invitee_id,
2767 session.user_id(),
2768 request.role().into(),
2769 )
2770 .await?;
2771
2772 let update = proto::UpdateChannels {
2773 channel_invitations: vec![channel.to_proto()],
2774 ..Default::default()
2775 };
2776
2777 let connection_pool = session.connection_pool().await;
2778 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2779 session.peer.send(connection_id, update.clone())?;
2780 }
2781
2782 send_notifications(&connection_pool, &session.peer, notifications);
2783
2784 response.send(proto::Ack {})?;
2785 Ok(())
2786}
2787
2788/// remove someone from a channel
2789async fn remove_channel_member(
2790 request: proto::RemoveChannelMember,
2791 response: Response<proto::RemoveChannelMember>,
2792 session: Session,
2793) -> Result<()> {
2794 let db = session.db().await;
2795 let channel_id = ChannelId::from_proto(request.channel_id);
2796 let member_id = UserId::from_proto(request.user_id);
2797
2798 let RemoveChannelMemberResult {
2799 membership_update,
2800 notification_id,
2801 } = db
2802 .remove_channel_member(channel_id, member_id, session.user_id())
2803 .await?;
2804
2805 let mut connection_pool = session.connection_pool().await;
2806 notify_membership_updated(
2807 &mut connection_pool,
2808 membership_update,
2809 member_id,
2810 &session.peer,
2811 );
2812 for connection_id in connection_pool.user_connection_ids(member_id) {
2813 if let Some(notification_id) = notification_id {
2814 session
2815 .peer
2816 .send(
2817 connection_id,
2818 proto::DeleteNotification {
2819 notification_id: notification_id.to_proto(),
2820 },
2821 )
2822 .trace_err();
2823 }
2824 }
2825
2826 response.send(proto::Ack {})?;
2827 Ok(())
2828}
2829
2830/// Toggle the channel between public and private.
2831/// Care is taken to maintain the invariant that public channels only descend from public channels,
2832/// (though members-only channels can appear at any point in the hierarchy).
2833async fn set_channel_visibility(
2834 request: proto::SetChannelVisibility,
2835 response: Response<proto::SetChannelVisibility>,
2836 session: Session,
2837) -> Result<()> {
2838 let db = session.db().await;
2839 let channel_id = ChannelId::from_proto(request.channel_id);
2840 let visibility = request.visibility().into();
2841
2842 let channel_model = db
2843 .set_channel_visibility(channel_id, visibility, session.user_id())
2844 .await?;
2845 let root_id = channel_model.root_id();
2846 let channel = Channel::from_model(channel_model);
2847
2848 let mut connection_pool = session.connection_pool().await;
2849 for (user_id, role) in connection_pool
2850 .channel_user_ids(root_id)
2851 .collect::<Vec<_>>()
2852 .into_iter()
2853 {
2854 let update = if role.can_see_channel(channel.visibility) {
2855 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2856 proto::UpdateChannels {
2857 channels: vec![channel.to_proto()],
2858 ..Default::default()
2859 }
2860 } else {
2861 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2862 proto::UpdateChannels {
2863 delete_channels: vec![channel.id.to_proto()],
2864 ..Default::default()
2865 }
2866 };
2867
2868 for connection_id in connection_pool.user_connection_ids(user_id) {
2869 session.peer.send(connection_id, update.clone())?;
2870 }
2871 }
2872
2873 response.send(proto::Ack {})?;
2874 Ok(())
2875}
2876
2877/// Alter the role for a user in the channel.
2878async fn set_channel_member_role(
2879 request: proto::SetChannelMemberRole,
2880 response: Response<proto::SetChannelMemberRole>,
2881 session: Session,
2882) -> Result<()> {
2883 let db = session.db().await;
2884 let channel_id = ChannelId::from_proto(request.channel_id);
2885 let member_id = UserId::from_proto(request.user_id);
2886 let result = db
2887 .set_channel_member_role(
2888 channel_id,
2889 session.user_id(),
2890 member_id,
2891 request.role().into(),
2892 )
2893 .await?;
2894
2895 match result {
2896 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2897 let mut connection_pool = session.connection_pool().await;
2898 notify_membership_updated(
2899 &mut connection_pool,
2900 membership_update,
2901 member_id,
2902 &session.peer,
2903 )
2904 }
2905 db::SetMemberRoleResult::InviteUpdated(channel) => {
2906 let update = proto::UpdateChannels {
2907 channel_invitations: vec![channel.to_proto()],
2908 ..Default::default()
2909 };
2910
2911 for connection_id in session
2912 .connection_pool()
2913 .await
2914 .user_connection_ids(member_id)
2915 {
2916 session.peer.send(connection_id, update.clone())?;
2917 }
2918 }
2919 }
2920
2921 response.send(proto::Ack {})?;
2922 Ok(())
2923}
2924
2925/// Change the name of a channel
2926async fn rename_channel(
2927 request: proto::RenameChannel,
2928 response: Response<proto::RenameChannel>,
2929 session: Session,
2930) -> Result<()> {
2931 let db = session.db().await;
2932 let channel_id = ChannelId::from_proto(request.channel_id);
2933 let channel_model = db
2934 .rename_channel(channel_id, session.user_id(), &request.name)
2935 .await?;
2936 let root_id = channel_model.root_id();
2937 let channel = Channel::from_model(channel_model);
2938
2939 response.send(proto::RenameChannelResponse {
2940 channel: Some(channel.to_proto()),
2941 })?;
2942
2943 let connection_pool = session.connection_pool().await;
2944 let update = proto::UpdateChannels {
2945 channels: vec![channel.to_proto()],
2946 ..Default::default()
2947 };
2948 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2949 if role.can_see_channel(channel.visibility) {
2950 session.peer.send(connection_id, update.clone())?;
2951 }
2952 }
2953
2954 Ok(())
2955}
2956
2957/// Move a channel to a new parent.
2958async fn move_channel(
2959 request: proto::MoveChannel,
2960 response: Response<proto::MoveChannel>,
2961 session: Session,
2962) -> Result<()> {
2963 let channel_id = ChannelId::from_proto(request.channel_id);
2964 let to = ChannelId::from_proto(request.to);
2965
2966 let (root_id, channels) = session
2967 .db()
2968 .await
2969 .move_channel(channel_id, to, session.user_id())
2970 .await?;
2971
2972 let connection_pool = session.connection_pool().await;
2973 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2974 let channels = channels
2975 .iter()
2976 .filter_map(|channel| {
2977 if role.can_see_channel(channel.visibility) {
2978 Some(channel.to_proto())
2979 } else {
2980 None
2981 }
2982 })
2983 .collect::<Vec<_>>();
2984 if channels.is_empty() {
2985 continue;
2986 }
2987
2988 let update = proto::UpdateChannels {
2989 channels,
2990 ..Default::default()
2991 };
2992
2993 session.peer.send(connection_id, update.clone())?;
2994 }
2995
2996 response.send(Ack {})?;
2997 Ok(())
2998}
2999
3000/// Get the list of channel members
3001async fn get_channel_members(
3002 request: proto::GetChannelMembers,
3003 response: Response<proto::GetChannelMembers>,
3004 session: Session,
3005) -> Result<()> {
3006 let db = session.db().await;
3007 let channel_id = ChannelId::from_proto(request.channel_id);
3008 let limit = if request.limit == 0 {
3009 u16::MAX as u64
3010 } else {
3011 request.limit
3012 };
3013 let (members, users) = db
3014 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3015 .await?;
3016 response.send(proto::GetChannelMembersResponse { members, users })?;
3017 Ok(())
3018}
3019
3020/// Accept or decline a channel invitation.
3021async fn respond_to_channel_invite(
3022 request: proto::RespondToChannelInvite,
3023 response: Response<proto::RespondToChannelInvite>,
3024 session: Session,
3025) -> Result<()> {
3026 let db = session.db().await;
3027 let channel_id = ChannelId::from_proto(request.channel_id);
3028 let RespondToChannelInvite {
3029 membership_update,
3030 notifications,
3031 } = db
3032 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3033 .await?;
3034
3035 let mut connection_pool = session.connection_pool().await;
3036 if let Some(membership_update) = membership_update {
3037 notify_membership_updated(
3038 &mut connection_pool,
3039 membership_update,
3040 session.user_id(),
3041 &session.peer,
3042 );
3043 } else {
3044 let update = proto::UpdateChannels {
3045 remove_channel_invitations: vec![channel_id.to_proto()],
3046 ..Default::default()
3047 };
3048
3049 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3050 session.peer.send(connection_id, update.clone())?;
3051 }
3052 };
3053
3054 send_notifications(&connection_pool, &session.peer, notifications);
3055
3056 response.send(proto::Ack {})?;
3057
3058 Ok(())
3059}
3060
3061/// Join the channels' room
3062async fn join_channel(
3063 request: proto::JoinChannel,
3064 response: Response<proto::JoinChannel>,
3065 session: Session,
3066) -> Result<()> {
3067 let channel_id = ChannelId::from_proto(request.channel_id);
3068 join_channel_internal(channel_id, Box::new(response), session).await
3069}
3070
/// A response handle that can deliver a `JoinRoomResponse` to the requester.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// request type.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully qualified `Response::<...>::send` call below presumably resolves to
// the inherent `Response::send`, not this trait method (which would recurse).
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3084
/// Joins the room of `channel_id` on behalf of the session's user and answers
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is removed from its room first via `leave_room_for_session`.
/// 2. The database join runs. It returns the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, a token is minted. A guest token with
///    `can_publish: false` is used for `ChannelRole::Guest`; otherwise a room
///    token with `can_publish: true`. A token error is logged via `trace_err`
///    and yields no connection info rather than failing the join.
/// 4. The `JoinRoomResponse` is sent, then any membership update and the room
///    update are broadcast.
/// 5. After the database guard is released, `channel_updated` and
///    `update_user_contacts` run.
///
/// # Errors
/// Returns an error if any of these fail:
/// - a database call;
/// - leaving the stale room;
/// - sending the response;
/// - updating contacts.
///
/// Also errors if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // `db` and `connection_pool` are scoped to this block, so both guards are
    // released before `channel_updated` / `update_user_contacts` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room, then re-acquire it.
            // Presumably `leave_room_for_session` acquires the db itself (confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the
        // token failed (the `?` after `trace_err()` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a non-publishing guest token; other roles may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The joined room is expected to belong to a channel; fail otherwise.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3180
3181/// Start editing the channel notes
3182async fn join_channel_buffer(
3183 request: proto::JoinChannelBuffer,
3184 response: Response<proto::JoinChannelBuffer>,
3185 session: Session,
3186) -> Result<()> {
3187 let db = session.db().await;
3188 let channel_id = ChannelId::from_proto(request.channel_id);
3189
3190 let open_response = db
3191 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3192 .await?;
3193
3194 let collaborators = open_response.collaborators.clone();
3195 response.send(open_response)?;
3196
3197 let update = UpdateChannelBufferCollaborators {
3198 channel_id: channel_id.to_proto(),
3199 collaborators: collaborators.clone(),
3200 };
3201 channel_buffer_updated(
3202 session.connection_id,
3203 collaborators
3204 .iter()
3205 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3206 &update,
3207 &session.peer,
3208 );
3209
3210 Ok(())
3211}
3212
3213/// Edit the channel notes
3214async fn update_channel_buffer(
3215 request: proto::UpdateChannelBuffer,
3216 session: Session,
3217) -> Result<()> {
3218 let db = session.db().await;
3219 let channel_id = ChannelId::from_proto(request.channel_id);
3220
3221 let (collaborators, epoch, version) = db
3222 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3223 .await?;
3224
3225 channel_buffer_updated(
3226 session.connection_id,
3227 collaborators.clone(),
3228 &proto::UpdateChannelBuffer {
3229 channel_id: channel_id.to_proto(),
3230 operations: request.operations,
3231 },
3232 &session.peer,
3233 );
3234
3235 let pool = &*session.connection_pool().await;
3236
3237 let non_collaborators =
3238 pool.channel_connection_ids(channel_id)
3239 .filter_map(|(connection_id, _)| {
3240 if collaborators.contains(&connection_id) {
3241 None
3242 } else {
3243 Some(connection_id)
3244 }
3245 });
3246
3247 broadcast(None, non_collaborators, |peer_id| {
3248 session.peer.send(
3249 peer_id,
3250 proto::UpdateChannels {
3251 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3252 channel_id: channel_id.to_proto(),
3253 epoch: epoch as u64,
3254 version: version.clone(),
3255 }],
3256 ..Default::default()
3257 },
3258 )
3259 });
3260
3261 Ok(())
3262}
3263
3264/// Rejoin the channel notes after a connection blip
3265async fn rejoin_channel_buffers(
3266 request: proto::RejoinChannelBuffers,
3267 response: Response<proto::RejoinChannelBuffers>,
3268 session: Session,
3269) -> Result<()> {
3270 let db = session.db().await;
3271 let buffers = db
3272 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3273 .await?;
3274
3275 for rejoined_buffer in &buffers {
3276 let collaborators_to_notify = rejoined_buffer
3277 .buffer
3278 .collaborators
3279 .iter()
3280 .filter_map(|c| Some(c.peer_id?.into()));
3281 channel_buffer_updated(
3282 session.connection_id,
3283 collaborators_to_notify,
3284 &proto::UpdateChannelBufferCollaborators {
3285 channel_id: rejoined_buffer.buffer.channel_id,
3286 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3287 },
3288 &session.peer,
3289 );
3290 }
3291
3292 response.send(proto::RejoinChannelBuffersResponse {
3293 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3294 })?;
3295
3296 Ok(())
3297}
3298
3299/// Stop editing the channel notes
3300async fn leave_channel_buffer(
3301 request: proto::LeaveChannelBuffer,
3302 response: Response<proto::LeaveChannelBuffer>,
3303 session: Session,
3304) -> Result<()> {
3305 let db = session.db().await;
3306 let channel_id = ChannelId::from_proto(request.channel_id);
3307
3308 let left_buffer = db
3309 .leave_channel_buffer(channel_id, session.connection_id)
3310 .await?;
3311
3312 response.send(Ack {})?;
3313
3314 channel_buffer_updated(
3315 session.connection_id,
3316 left_buffer.connections,
3317 &proto::UpdateChannelBufferCollaborators {
3318 channel_id: channel_id.to_proto(),
3319 collaborators: left_buffer.collaborators,
3320 },
3321 &session.peer,
3322 );
3323
3324 Ok(())
3325}
3326
3327fn channel_buffer_updated<T: EnvelopedMessage>(
3328 sender_id: ConnectionId,
3329 collaborators: impl IntoIterator<Item = ConnectionId>,
3330 message: &T,
3331 peer: &Peer,
3332) {
3333 broadcast(Some(sender_id), collaborators, |peer_id| {
3334 peer.send(peer_id, message.clone())
3335 });
3336}
3337
3338fn send_notifications(
3339 connection_pool: &ConnectionPool,
3340 peer: &Peer,
3341 notifications: db::NotificationBatch,
3342) {
3343 for (user_id, notification) in notifications {
3344 for connection_id in connection_pool.user_connection_ids(user_id) {
3345 if let Err(error) = peer.send(
3346 connection_id,
3347 proto::AddNotification {
3348 notification: Some(notification.clone()),
3349 },
3350 ) {
3351 tracing::error!(
3352 "failed to send notification to {:?} {}",
3353 connection_id,
3354 error
3355 );
3356 }
3357 }
3358 }
3359}
3360
3361/// Send a message to the channel
3362async fn send_channel_message(
3363 request: proto::SendChannelMessage,
3364 response: Response<proto::SendChannelMessage>,
3365 session: Session,
3366) -> Result<()> {
3367 // Validate the message body.
3368 let body = request.body.trim().to_string();
3369 if body.len() > MAX_MESSAGE_LEN {
3370 return Err(anyhow!("message is too long"))?;
3371 }
3372 if body.is_empty() {
3373 return Err(anyhow!("message can't be blank"))?;
3374 }
3375
3376 // TODO: adjust mentions if body is trimmed
3377
3378 let timestamp = OffsetDateTime::now_utc();
3379 let nonce = request
3380 .nonce
3381 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3382
3383 let channel_id = ChannelId::from_proto(request.channel_id);
3384 let CreatedChannelMessage {
3385 message_id,
3386 participant_connection_ids,
3387 notifications,
3388 } = session
3389 .db()
3390 .await
3391 .create_channel_message(
3392 channel_id,
3393 session.user_id(),
3394 &body,
3395 &request.mentions,
3396 timestamp,
3397 nonce.clone().into(),
3398 request.reply_to_message_id.map(MessageId::from_proto),
3399 )
3400 .await?;
3401
3402 let message = proto::ChannelMessage {
3403 sender_id: session.user_id().to_proto(),
3404 id: message_id.to_proto(),
3405 body,
3406 mentions: request.mentions,
3407 timestamp: timestamp.unix_timestamp() as u64,
3408 nonce: Some(nonce),
3409 reply_to_message_id: request.reply_to_message_id,
3410 edited_at: None,
3411 };
3412 broadcast(
3413 Some(session.connection_id),
3414 participant_connection_ids.clone(),
3415 |connection| {
3416 session.peer.send(
3417 connection,
3418 proto::ChannelMessageSent {
3419 channel_id: channel_id.to_proto(),
3420 message: Some(message.clone()),
3421 },
3422 )
3423 },
3424 );
3425 response.send(proto::SendChannelMessageResponse {
3426 message: Some(message),
3427 })?;
3428
3429 let pool = &*session.connection_pool().await;
3430 let non_participants =
3431 pool.channel_connection_ids(channel_id)
3432 .filter_map(|(connection_id, _)| {
3433 if participant_connection_ids.contains(&connection_id) {
3434 None
3435 } else {
3436 Some(connection_id)
3437 }
3438 });
3439 broadcast(None, non_participants, |peer_id| {
3440 session.peer.send(
3441 peer_id,
3442 proto::UpdateChannels {
3443 latest_channel_message_ids: vec![proto::ChannelMessageId {
3444 channel_id: channel_id.to_proto(),
3445 message_id: message_id.to_proto(),
3446 }],
3447 ..Default::default()
3448 },
3449 )
3450 });
3451 send_notifications(pool, &session.peer, notifications);
3452
3453 Ok(())
3454}
3455
3456/// Delete a channel message
3457async fn remove_channel_message(
3458 request: proto::RemoveChannelMessage,
3459 response: Response<proto::RemoveChannelMessage>,
3460 session: Session,
3461) -> Result<()> {
3462 let channel_id = ChannelId::from_proto(request.channel_id);
3463 let message_id = MessageId::from_proto(request.message_id);
3464 let (connection_ids, existing_notification_ids) = session
3465 .db()
3466 .await
3467 .remove_channel_message(channel_id, message_id, session.user_id())
3468 .await?;
3469
3470 broadcast(
3471 Some(session.connection_id),
3472 connection_ids,
3473 move |connection| {
3474 session.peer.send(connection, request.clone())?;
3475
3476 for notification_id in &existing_notification_ids {
3477 session.peer.send(
3478 connection,
3479 proto::DeleteNotification {
3480 notification_id: (*notification_id).to_proto(),
3481 },
3482 )?;
3483 }
3484
3485 Ok(())
3486 },
3487 );
3488 response.send(proto::Ack {})?;
3489 Ok(())
3490}
3491
3492async fn update_channel_message(
3493 request: proto::UpdateChannelMessage,
3494 response: Response<proto::UpdateChannelMessage>,
3495 session: Session,
3496) -> Result<()> {
3497 let channel_id = ChannelId::from_proto(request.channel_id);
3498 let message_id = MessageId::from_proto(request.message_id);
3499 let updated_at = OffsetDateTime::now_utc();
3500 let UpdatedChannelMessage {
3501 message_id,
3502 participant_connection_ids,
3503 notifications,
3504 reply_to_message_id,
3505 timestamp,
3506 deleted_mention_notification_ids,
3507 updated_mention_notifications,
3508 } = session
3509 .db()
3510 .await
3511 .update_channel_message(
3512 channel_id,
3513 message_id,
3514 session.user_id(),
3515 request.body.as_str(),
3516 &request.mentions,
3517 updated_at,
3518 )
3519 .await?;
3520
3521 let nonce = request
3522 .nonce
3523 .clone()
3524 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3525
3526 let message = proto::ChannelMessage {
3527 sender_id: session.user_id().to_proto(),
3528 id: message_id.to_proto(),
3529 body: request.body.clone(),
3530 mentions: request.mentions.clone(),
3531 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3532 nonce: Some(nonce),
3533 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3534 edited_at: Some(updated_at.unix_timestamp() as u64),
3535 };
3536
3537 response.send(proto::Ack {})?;
3538
3539 let pool = &*session.connection_pool().await;
3540 broadcast(
3541 Some(session.connection_id),
3542 participant_connection_ids,
3543 |connection| {
3544 session.peer.send(
3545 connection,
3546 proto::ChannelMessageUpdate {
3547 channel_id: channel_id.to_proto(),
3548 message: Some(message.clone()),
3549 },
3550 )?;
3551
3552 for notification_id in &deleted_mention_notification_ids {
3553 session.peer.send(
3554 connection,
3555 proto::DeleteNotification {
3556 notification_id: (*notification_id).to_proto(),
3557 },
3558 )?;
3559 }
3560
3561 for notification in &updated_mention_notifications {
3562 session.peer.send(
3563 connection,
3564 proto::UpdateNotification {
3565 notification: Some(notification.clone()),
3566 },
3567 )?;
3568 }
3569
3570 Ok(())
3571 },
3572 );
3573
3574 send_notifications(pool, &session.peer, notifications);
3575
3576 Ok(())
3577}
3578
3579/// Mark a channel message as read
3580async fn acknowledge_channel_message(
3581 request: proto::AckChannelMessage,
3582 session: Session,
3583) -> Result<()> {
3584 let channel_id = ChannelId::from_proto(request.channel_id);
3585 let message_id = MessageId::from_proto(request.message_id);
3586 let notifications = session
3587 .db()
3588 .await
3589 .observe_channel_message(channel_id, session.user_id(), message_id)
3590 .await?;
3591 send_notifications(
3592 &*session.connection_pool().await,
3593 &session.peer,
3594 notifications,
3595 );
3596 Ok(())
3597}
3598
3599/// Mark a buffer version as synced
3600async fn acknowledge_buffer_version(
3601 request: proto::AckBufferOperation,
3602 session: Session,
3603) -> Result<()> {
3604 let buffer_id = BufferId::from_proto(request.buffer_id);
3605 session
3606 .db()
3607 .await
3608 .observe_buffer_version(
3609 buffer_id,
3610 session.user_id(),
3611 request.epoch as i32,
3612 &request.version,
3613 )
3614 .await?;
3615 Ok(())
3616}
3617
3618async fn count_language_model_tokens(
3619 request: proto::CountLanguageModelTokens,
3620 response: Response<proto::CountLanguageModelTokens>,
3621 session: Session,
3622 config: &Config,
3623) -> Result<()> {
3624 authorize_access_to_legacy_llm_endpoints(&session).await?;
3625
3626 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3627 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3628 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3629 };
3630
3631 session
3632 .app_state
3633 .rate_limiter
3634 .check(&*rate_limit, session.user_id())
3635 .await?;
3636
3637 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3638 Some(proto::LanguageModelProvider::Google) => {
3639 let api_key = config
3640 .google_ai_api_key
3641 .as_ref()
3642 .context("no Google AI API key configured on the server")?;
3643 google_ai::count_tokens(
3644 session.http_client.as_ref(),
3645 google_ai::API_URL,
3646 api_key,
3647 serde_json::from_str(&request.request)?,
3648 )
3649 .await?
3650 }
3651 _ => return Err(anyhow!("unsupported provider"))?,
3652 };
3653
3654 response.send(proto::CountLanguageModelTokensResponse {
3655 token_count: result.total_tokens as u32,
3656 })?;
3657
3658 Ok(())
3659}
3660
3661struct ZedProCountLanguageModelTokensRateLimit;
3662
3663impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3664 fn capacity(&self) -> usize {
3665 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3666 .ok()
3667 .and_then(|v| v.parse().ok())
3668 .unwrap_or(600) // Picked arbitrarily
3669 }
3670
3671 fn refill_duration(&self) -> chrono::Duration {
3672 chrono::Duration::hours(1)
3673 }
3674
3675 fn db_name(&self) -> &'static str {
3676 "zed-pro:count-language-model-tokens"
3677 }
3678}
3679
3680struct FreeCountLanguageModelTokensRateLimit;
3681
3682impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3683 fn capacity(&self) -> usize {
3684 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3685 .ok()
3686 .and_then(|v| v.parse().ok())
3687 .unwrap_or(600 / 10) // Picked arbitrarily
3688 }
3689
3690 fn refill_duration(&self) -> chrono::Duration {
3691 chrono::Duration::hours(1)
3692 }
3693
3694 fn db_name(&self) -> &'static str {
3695 "free:count-language-model-tokens"
3696 }
3697}
3698
3699struct ZedProComputeEmbeddingsRateLimit;
3700
3701impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3702 fn capacity(&self) -> usize {
3703 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3704 .ok()
3705 .and_then(|v| v.parse().ok())
3706 .unwrap_or(5000) // Picked arbitrarily
3707 }
3708
3709 fn refill_duration(&self) -> chrono::Duration {
3710 chrono::Duration::hours(1)
3711 }
3712
3713 fn db_name(&self) -> &'static str {
3714 "zed-pro:compute-embeddings"
3715 }
3716}
3717
3718struct FreeComputeEmbeddingsRateLimit;
3719
3720impl RateLimit for FreeComputeEmbeddingsRateLimit {
3721 fn capacity(&self) -> usize {
3722 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3723 .ok()
3724 .and_then(|v| v.parse().ok())
3725 .unwrap_or(5000 / 10) // Picked arbitrarily
3726 }
3727
3728 fn refill_duration(&self) -> chrono::Duration {
3729 chrono::Duration::hours(1)
3730 }
3731
3732 fn db_name(&self) -> &'static str {
3733 "free:compute-embeddings"
3734 }
3735}
3736
3737async fn compute_embeddings(
3738 request: proto::ComputeEmbeddings,
3739 response: Response<proto::ComputeEmbeddings>,
3740 session: Session,
3741 api_key: Option<Arc<str>>,
3742) -> Result<()> {
3743 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3744 authorize_access_to_legacy_llm_endpoints(&session).await?;
3745
3746 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3747 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3748 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3749 };
3750
3751 session
3752 .app_state
3753 .rate_limiter
3754 .check(&*rate_limit, session.user_id())
3755 .await?;
3756
3757 let embeddings = match request.model.as_str() {
3758 "openai/text-embedding-3-small" => {
3759 open_ai::embed(
3760 session.http_client.as_ref(),
3761 OPEN_AI_API_URL,
3762 &api_key,
3763 OpenAiEmbeddingModel::TextEmbedding3Small,
3764 request.texts.iter().map(|text| text.as_str()),
3765 )
3766 .await?
3767 }
3768 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3769 };
3770
3771 let embeddings = request
3772 .texts
3773 .iter()
3774 .map(|text| {
3775 let mut hasher = sha2::Sha256::new();
3776 hasher.update(text.as_bytes());
3777 let result = hasher.finalize();
3778 result.to_vec()
3779 })
3780 .zip(
3781 embeddings
3782 .data
3783 .into_iter()
3784 .map(|embedding| embedding.embedding),
3785 )
3786 .collect::<HashMap<_, _>>();
3787
3788 let db = session.db().await;
3789 db.save_embeddings(&request.model, &embeddings)
3790 .await
3791 .context("failed to save embeddings")
3792 .trace_err();
3793
3794 response.send(proto::ComputeEmbeddingsResponse {
3795 embeddings: embeddings
3796 .into_iter()
3797 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3798 .collect(),
3799 })?;
3800 Ok(())
3801}
3802
3803async fn get_cached_embeddings(
3804 request: proto::GetCachedEmbeddings,
3805 response: Response<proto::GetCachedEmbeddings>,
3806 session: Session,
3807) -> Result<()> {
3808 authorize_access_to_legacy_llm_endpoints(&session).await?;
3809
3810 let db = session.db().await;
3811 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3812
3813 response.send(proto::GetCachedEmbeddingsResponse {
3814 embeddings: embeddings
3815 .into_iter()
3816 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3817 .collect(),
3818 })?;
3819 Ok(())
3820}
3821
3822/// This is leftover from before the LLM service.
3823///
3824/// The endpoints protected by this check will be moved there eventually.
3825async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3826 if session.is_staff() {
3827 Ok(())
3828 } else {
3829 Err(anyhow!("permission denied"))?
3830 }
3831}
3832
3833/// Get a Supermaven API key for the user
3834async fn get_supermaven_api_key(
3835 _request: proto::GetSupermavenApiKey,
3836 response: Response<proto::GetSupermavenApiKey>,
3837 session: Session,
3838) -> Result<()> {
3839 let user_id: String = session.user_id().to_string();
3840 if !session.is_staff() {
3841 return Err(anyhow!("supermaven not enabled for this account"))?;
3842 }
3843
3844 let email = session
3845 .email()
3846 .ok_or_else(|| anyhow!("user must have an email"))?;
3847
3848 let supermaven_admin_api = session
3849 .supermaven_client
3850 .as_ref()
3851 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3852
3853 let result = supermaven_admin_api
3854 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3855 .await?;
3856
3857 response.send(proto::GetSupermavenApiKeyResponse {
3858 api_key: result.api_key,
3859 })?;
3860
3861 Ok(())
3862}
3863
3864/// Start receiving chat updates for a channel
3865async fn join_channel_chat(
3866 request: proto::JoinChannelChat,
3867 response: Response<proto::JoinChannelChat>,
3868 session: Session,
3869) -> Result<()> {
3870 let channel_id = ChannelId::from_proto(request.channel_id);
3871
3872 let db = session.db().await;
3873 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3874 .await?;
3875 let messages = db
3876 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3877 .await?;
3878 response.send(proto::JoinChannelChatResponse {
3879 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3880 messages,
3881 })?;
3882 Ok(())
3883}
3884
3885/// Stop receiving chat updates for a channel
3886async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3887 let channel_id = ChannelId::from_proto(request.channel_id);
3888 session
3889 .db()
3890 .await
3891 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3892 .await?;
3893 Ok(())
3894}
3895
3896/// Retrieve the chat history for a channel
3897async fn get_channel_messages(
3898 request: proto::GetChannelMessages,
3899 response: Response<proto::GetChannelMessages>,
3900 session: Session,
3901) -> Result<()> {
3902 let channel_id = ChannelId::from_proto(request.channel_id);
3903 let messages = session
3904 .db()
3905 .await
3906 .get_channel_messages(
3907 channel_id,
3908 session.user_id(),
3909 MESSAGE_COUNT_PER_PAGE,
3910 Some(MessageId::from_proto(request.before_message_id)),
3911 )
3912 .await?;
3913 response.send(proto::GetChannelMessagesResponse {
3914 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3915 messages,
3916 })?;
3917 Ok(())
3918}
3919
3920/// Retrieve specific chat messages
3921async fn get_channel_messages_by_id(
3922 request: proto::GetChannelMessagesById,
3923 response: Response<proto::GetChannelMessagesById>,
3924 session: Session,
3925) -> Result<()> {
3926 let message_ids = request
3927 .message_ids
3928 .iter()
3929 .map(|id| MessageId::from_proto(*id))
3930 .collect::<Vec<_>>();
3931 let messages = session
3932 .db()
3933 .await
3934 .get_channel_messages_by_id(session.user_id(), &message_ids)
3935 .await?;
3936 response.send(proto::GetChannelMessagesResponse {
3937 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3938 messages,
3939 })?;
3940 Ok(())
3941}
3942
3943/// Retrieve the current users notifications
3944async fn get_notifications(
3945 request: proto::GetNotifications,
3946 response: Response<proto::GetNotifications>,
3947 session: Session,
3948) -> Result<()> {
3949 let notifications = session
3950 .db()
3951 .await
3952 .get_notifications(
3953 session.user_id(),
3954 NOTIFICATION_COUNT_PER_PAGE,
3955 request.before_id.map(db::NotificationId::from_proto),
3956 )
3957 .await?;
3958 response.send(proto::GetNotificationsResponse {
3959 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3960 notifications,
3961 })?;
3962 Ok(())
3963}
3964
3965/// Mark notifications as read
3966async fn mark_notification_as_read(
3967 request: proto::MarkNotificationRead,
3968 response: Response<proto::MarkNotificationRead>,
3969 session: Session,
3970) -> Result<()> {
3971 let database = &session.db().await;
3972 let notifications = database
3973 .mark_notification_as_read_by_id(
3974 session.user_id(),
3975 NotificationId::from_proto(request.notification_id),
3976 )
3977 .await?;
3978 send_notifications(
3979 &*session.connection_pool().await,
3980 &session.peer,
3981 notifications,
3982 );
3983 response.send(proto::Ack {})?;
3984 Ok(())
3985}
3986
3987/// Get the current users information
3988async fn get_private_user_info(
3989 _request: proto::GetPrivateUserInfo,
3990 response: Response<proto::GetPrivateUserInfo>,
3991 session: Session,
3992) -> Result<()> {
3993 let db = session.db().await;
3994
3995 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3996 let user = db
3997 .get_user_by_id(session.user_id())
3998 .await?
3999 .ok_or_else(|| anyhow!("user not found"))?;
4000 let flags = db.get_user_flags(session.user_id()).await?;
4001
4002 response.send(proto::GetPrivateUserInfoResponse {
4003 metrics_id,
4004 staff: user.admin,
4005 flags,
4006 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4007 })?;
4008 Ok(())
4009}
4010
4011/// Accept the terms of service (tos) on behalf of the current user
4012async fn accept_terms_of_service(
4013 _request: proto::AcceptTermsOfService,
4014 response: Response<proto::AcceptTermsOfService>,
4015 session: Session,
4016) -> Result<()> {
4017 let db = session.db().await;
4018
4019 let accepted_tos_at = Utc::now();
4020 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4021 .await?;
4022
4023 response.send(proto::AcceptTermsOfServiceResponse {
4024 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4025 })?;
4026 Ok(())
4027}
4028
4029/// The minimum account age an account must have in order to use the LLM service.
4030const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4031
4032async fn get_llm_api_token(
4033 _request: proto::GetLlmToken,
4034 response: Response<proto::GetLlmToken>,
4035 session: Session,
4036) -> Result<()> {
4037 let db = session.db().await;
4038
4039 let flags = db.get_user_flags(session.user_id()).await?;
4040 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4041 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4042 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4043
4044 if !session.is_staff() && !has_language_models_feature_flag {
4045 Err(anyhow!("permission denied"))?
4046 }
4047
4048 let user_id = session.user_id();
4049 let user = db
4050 .get_user_by_id(user_id)
4051 .await?
4052 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4053
4054 if user.accepted_tos_at.is_none() {
4055 Err(anyhow!("terms of service not accepted"))?
4056 }
4057
4058 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4059
4060 let bypass_account_age_check =
4061 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4062 if !bypass_account_age_check {
4063 let mut account_created_at = user.created_at;
4064 if let Some(github_created_at) = user.github_user_created_at {
4065 account_created_at = account_created_at.min(github_created_at);
4066 }
4067 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4068 Err(anyhow!("account too young"))?
4069 }
4070 }
4071
4072 let billing_preferences = db.get_billing_preferences(user.id).await?;
4073
4074 let token = LlmTokenClaims::create(
4075 &user,
4076 session.is_staff(),
4077 billing_preferences,
4078 has_llm_closed_beta_feature_flag,
4079 has_predict_edits_feature_flag,
4080 has_llm_subscription,
4081 session.current_plan(&db).await?,
4082 session.system_id.clone(),
4083 &session.app_state.config,
4084 )?;
4085 response.send(proto::GetLlmTokenResponse { token })?;
4086 Ok(())
4087}
4088
4089fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4090 let message = match message {
4091 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4092 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4093 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4094 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4095 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4096 code: frame.code.into(),
4097 reason: frame.reason,
4098 })),
4099 // We should never receive a frame while reading the message, according
4100 // to the `tungstenite` maintainers:
4101 //
4102 // > It cannot occur when you read messages from the WebSocket, but it
4103 // > can be used when you want to send the raw frames (e.g. you want to
4104 // > send the frames to the WebSocket without composing the full message first).
4105 // >
4106 // > — https://github.com/snapview/tungstenite-rs/issues/268
4107 TungsteniteMessage::Frame(_) => {
4108 bail!("received an unexpected frame while reading the message")
4109 }
4110 };
4111
4112 Ok(message)
4113}
4114
4115fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4116 match message {
4117 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4118 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4119 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4120 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4121 AxumMessage::Close(frame) => {
4122 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4123 code: frame.code.into(),
4124 reason: frame.reason,
4125 }))
4126 }
4127 }
4128}
4129
4130fn notify_membership_updated(
4131 connection_pool: &mut ConnectionPool,
4132 result: MembershipUpdated,
4133 user_id: UserId,
4134 peer: &Peer,
4135) {
4136 for membership in &result.new_channels.channel_memberships {
4137 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4138 }
4139 for channel_id in &result.removed_channels {
4140 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4141 }
4142
4143 let user_channels_update = proto::UpdateUserChannels {
4144 channel_memberships: result
4145 .new_channels
4146 .channel_memberships
4147 .iter()
4148 .map(|cm| proto::ChannelMembership {
4149 channel_id: cm.channel_id.to_proto(),
4150 role: cm.role.into(),
4151 })
4152 .collect(),
4153 ..Default::default()
4154 };
4155
4156 let mut update = build_channels_update(result.new_channels);
4157 update.delete_channels = result
4158 .removed_channels
4159 .into_iter()
4160 .map(|id| id.to_proto())
4161 .collect();
4162 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4163
4164 for connection_id in connection_pool.user_connection_ids(user_id) {
4165 peer.send(connection_id, user_channels_update.clone())
4166 .trace_err();
4167 peer.send(connection_id, update.clone()).trace_err();
4168 }
4169}
4170
4171fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4172 proto::UpdateUserChannels {
4173 channel_memberships: channels
4174 .channel_memberships
4175 .iter()
4176 .map(|m| proto::ChannelMembership {
4177 channel_id: m.channel_id.to_proto(),
4178 role: m.role.into(),
4179 })
4180 .collect(),
4181 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4182 observed_channel_message_id: channels.observed_channel_messages.clone(),
4183 }
4184}
4185
4186fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4187 let mut update = proto::UpdateChannels::default();
4188
4189 for channel in channels.channels {
4190 update.channels.push(channel.to_proto());
4191 }
4192
4193 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4194 update.latest_channel_message_ids = channels.latest_channel_messages;
4195
4196 for (channel_id, participants) in channels.channel_participants {
4197 update
4198 .channel_participants
4199 .push(proto::ChannelParticipants {
4200 channel_id: channel_id.to_proto(),
4201 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4202 });
4203 }
4204
4205 for channel in channels.invited_channels {
4206 update.channel_invitations.push(channel.to_proto());
4207 }
4208
4209 update
4210}
4211
4212fn build_initial_contacts_update(
4213 contacts: Vec<db::Contact>,
4214 pool: &ConnectionPool,
4215) -> proto::UpdateContacts {
4216 let mut update = proto::UpdateContacts::default();
4217
4218 for contact in contacts {
4219 match contact {
4220 db::Contact::Accepted { user_id, busy } => {
4221 update.contacts.push(contact_for_user(user_id, busy, pool));
4222 }
4223 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4224 db::Contact::Incoming { user_id } => {
4225 update
4226 .incoming_requests
4227 .push(proto::IncomingContactRequest {
4228 requester_id: user_id.to_proto(),
4229 })
4230 }
4231 }
4232 }
4233
4234 update
4235}
4236
4237fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4238 proto::Contact {
4239 user_id: user_id.to_proto(),
4240 online: pool.is_user_online(user_id),
4241 busy,
4242 }
4243}
4244
4245fn room_updated(room: &proto::Room, peer: &Peer) {
4246 broadcast(
4247 None,
4248 room.participants
4249 .iter()
4250 .filter_map(|participant| Some(participant.peer_id?.into())),
4251 |peer_id| {
4252 peer.send(
4253 peer_id,
4254 proto::RoomUpdated {
4255 room: Some(room.clone()),
4256 },
4257 )
4258 },
4259 );
4260}
4261
4262fn channel_updated(
4263 channel: &db::channel::Model,
4264 room: &proto::Room,
4265 peer: &Peer,
4266 pool: &ConnectionPool,
4267) {
4268 let participants = room
4269 .participants
4270 .iter()
4271 .map(|p| p.user_id)
4272 .collect::<Vec<_>>();
4273
4274 broadcast(
4275 None,
4276 pool.channel_connection_ids(channel.root_id())
4277 .filter_map(|(channel_id, role)| {
4278 role.can_see_channel(channel.visibility)
4279 .then_some(channel_id)
4280 }),
4281 |peer_id| {
4282 peer.send(
4283 peer_id,
4284 proto::UpdateChannels {
4285 channel_participants: vec![proto::ChannelParticipants {
4286 channel_id: channel.id.to_proto(),
4287 participant_user_ids: participants.clone(),
4288 }],
4289 ..Default::default()
4290 },
4291 )
4292 },
4293 );
4294}
4295
4296async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4297 let db = session.db().await;
4298
4299 let contacts = db.get_contacts(user_id).await?;
4300 let busy = db.is_user_busy(user_id).await?;
4301
4302 let pool = session.connection_pool().await;
4303 let updated_contact = contact_for_user(user_id, busy, &pool);
4304 for contact in contacts {
4305 if let db::Contact::Accepted {
4306 user_id: contact_user_id,
4307 ..
4308 } = contact
4309 {
4310 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4311 session
4312 .peer
4313 .send(
4314 contact_conn_id,
4315 proto::UpdateContacts {
4316 contacts: vec![updated_contact.clone()],
4317 remove_contacts: Default::default(),
4318 incoming_requests: Default::default(),
4319 remove_incoming_requests: Default::default(),
4320 outgoing_requests: Default::default(),
4321 remove_outgoing_requests: Default::default(),
4322 },
4323 )
4324 .trace_err();
4325 }
4326 }
4327 }
4328 Ok(())
4329}
4330
4331async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4332 let mut contacts_to_update = HashSet::default();
4333
4334 let room_id;
4335 let canceled_calls_to_user_ids;
4336 let livekit_room;
4337 let delete_livekit_room;
4338 let room;
4339 let channel;
4340
4341 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4342 contacts_to_update.insert(session.user_id());
4343
4344 for project in left_room.left_projects.values() {
4345 project_left(project, session);
4346 }
4347
4348 room_id = RoomId::from_proto(left_room.room.id);
4349 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4350 livekit_room = mem::take(&mut left_room.room.livekit_room);
4351 delete_livekit_room = left_room.deleted;
4352 room = mem::take(&mut left_room.room);
4353 channel = mem::take(&mut left_room.channel);
4354
4355 room_updated(&room, &session.peer);
4356 } else {
4357 return Ok(());
4358 }
4359
4360 if let Some(channel) = channel {
4361 channel_updated(
4362 &channel,
4363 &room,
4364 &session.peer,
4365 &*session.connection_pool().await,
4366 );
4367 }
4368
4369 {
4370 let pool = session.connection_pool().await;
4371 for canceled_user_id in canceled_calls_to_user_ids {
4372 for connection_id in pool.user_connection_ids(canceled_user_id) {
4373 session
4374 .peer
4375 .send(
4376 connection_id,
4377 proto::CallCanceled {
4378 room_id: room_id.to_proto(),
4379 },
4380 )
4381 .trace_err();
4382 }
4383 contacts_to_update.insert(canceled_user_id);
4384 }
4385 }
4386
4387 for contact_user_id in contacts_to_update {
4388 update_user_contacts(contact_user_id, session).await?;
4389 }
4390
4391 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4392 live_kit
4393 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4394 .await
4395 .trace_err();
4396
4397 if delete_livekit_room {
4398 live_kit.delete_room(livekit_room).await.trace_err();
4399 }
4400 }
4401
4402 Ok(())
4403}
4404
4405async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4406 let left_channel_buffers = session
4407 .db()
4408 .await
4409 .leave_channel_buffers(session.connection_id)
4410 .await?;
4411
4412 for left_buffer in left_channel_buffers {
4413 channel_buffer_updated(
4414 session.connection_id,
4415 left_buffer.connections,
4416 &proto::UpdateChannelBufferCollaborators {
4417 channel_id: left_buffer.channel_id.to_proto(),
4418 collaborators: left_buffer.collaborators,
4419 },
4420 &session.peer,
4421 );
4422 }
4423
4424 Ok(())
4425}
4426
4427fn project_left(project: &db::LeftProject, session: &Session) {
4428 for connection_id in &project.connection_ids {
4429 if project.should_unshare {
4430 session
4431 .peer
4432 .send(
4433 *connection_id,
4434 proto::UnshareProject {
4435 project_id: project.id.to_proto(),
4436 },
4437 )
4438 .trace_err();
4439 } else {
4440 session
4441 .peer
4442 .send(
4443 *connection_id,
4444 proto::RemoveProjectCollaborator {
4445 project_id: project.id.to_proto(),
4446 peer_id: Some(session.connection_id.into()),
4447 },
4448 )
4449 .trace_err();
4450 }
4451 }
4452}
4453
4454pub trait ResultExt {
4455 type Ok;
4456
4457 fn trace_err(self) -> Option<Self::Ok>;
4458}
4459
4460impl<T, E> ResultExt for Result<T, E>
4461where
4462 E: std::fmt::Debug,
4463{
4464 type Ok = T;
4465
4466 #[track_caller]
4467 fn trace_err(self) -> Option<T> {
4468 match self {
4469 Ok(value) => Some(value),
4470 Err(error) => {
4471 tracing::error!("{:?}", error);
4472 None
4473 }
4474 }
4475 }
4476}