1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Thirty-second reconnection window.
///
/// NOTE(review): the code that consumes this constant is outside this view; presumably it
/// bounds how long a dropped client may take to reconnect — confirm against the caller.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
/// Delay that [`Server::start`] sleeps for before it clears stale rooms, stale channel buffer
/// collaborators and stale server records.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel messages.
/// NOTE(review): used outside this view — presumably by the channel message handlers.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on the length of a message.
/// NOTE(review): used outside this view; unit (bytes vs. characters) is not visible here.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notifications.
/// NOTE(review): used outside this view — presumably by `get_notifications`.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler, as stored in the `handlers` map of [`Server`].
///
/// It receives the boxed envelope of an incoming message together with the [`Session`] of the
/// connection it arrived on, and returns a boxed future that completes when handling is done.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// One-shot reply handle given to request handlers for a request of type `R`.
struct Response<R> {
    /// Peer through which the reply is delivered.
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` reads it to detect handlers that
    /// returned `Ok` without ever replying.
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// Identity on whose behalf a connection acts.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`. The session's user id and email are those of `user`, while
    /// `admin`'s login is recorded as the `impersonator` in tracing spans.
    Impersonated { user: User, admin: User },
}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed to every message handler.
///
/// `Server::handle_connection` builds one `Session` per connection and clones it for each
/// incoming message; most members are shared through `Arc`.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Id assigned to the connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and replies.
    peer: Arc<Peer>,
    /// Pool of live connections shared with the [`Server`]; acquire it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (configuration, database, executor, ...).
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no `supermaven_admin_api_key` is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was opened.
    /// NOTE(review): its origin is outside this view — presumably the `SystemIdHeader`.
    system_id: Option<String>,
    /// Executor the connection runs on; kept but not read in the code visible here.
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
/// Newtype around the shared [`Database`]; each [`Session`] keeps one behind an async mutex.
struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because test builds can replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that multiplexes all client connections.
    peer: Arc<Peer>,
    /// Live connections, shared with every [`Session`].
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (configuration, database, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is tearing down; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
244
/// Guard over the locked [`ConnectionPool`].
///
/// Dereferences to the pool and, in test builds, checks the pool's invariants when dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is not `Send`, so this marker keeps the guard from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server's peer and connection pool, produced by `Server::snapshot`.
///
/// Holds the connection pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Locked connection pool; serialized through `serialize_deref`, i.e. as the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
    /// Creates a server with the given id and registers a handler for every supported message.
    ///
    /// Each message type may be registered only once; a duplicate registration panics inside
    /// `add_handler`. The returned server is not yet doing any work: call `start` to schedule
    /// stale-resource cleanup and `handle_connection` for each client connection.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created on demand with `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project state updates.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through one of two generic forwarders
            // (`forward_read_only_project_request` / `forward_mutating_project_request`).
            // NOTE(review): the forwarders are defined outside this view; presumably they differ
            // in the capability required of the requesting collaborator — confirm.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            // NOTE(review): `GetCodeLens` is the only request in this run that uses the mutating
            // forwarder although its name suggests a query — confirm this is intended.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer traffic and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following, account information and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` sound like operations that modify
            // the host's working tree, yet they use the read-only forwarder — confirm intent.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The remaining handlers need configuration, so they are wrapped in closures that
            // capture a clone of the application state.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
439
    /// Schedules cleanup of stale server resources and returns immediately with `Ok(())`.
    ///
    /// A detached task sleeps for `CLEANUP_TIMEOUT`, asks the database for stale room and channel
    /// ids, clears stale collaborators and participants, notifies the affected connections, and
    /// finally deletes the stale server records. Every failure inside the task is logged through
    /// `trace_err` and otherwise ignored, so one failing step does not abort the rest.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the refreshed
                    // collaborator list to the connections the database reports for it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the refreshed room is in scope and used afterwards.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // A room left without participants also loses its LiveKit room below.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so it is released before the
                        // next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of every affected user to each of
                        // their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
585
586 pub fn teardown(&self) {
587 self.peer.teardown();
588 self.connection_pool.lock().reset();
589 let _ = self.teardown.send(true);
590 }
591
592 #[cfg(test)]
593 pub fn reset(&self, id: ServerId) {
594 self.teardown();
595 *self.id.lock() = id;
596 self.peer.reset(id.0 as u32);
597 let _ = self.teardown.send(false);
598 }
599
600 #[cfg(test)]
601 pub fn id(&self) -> ServerId {
602 *self.id.lock()
603 }
604
605 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
606 where
607 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
608 Fut: 'static + Send + Future<Output = Result<()>>,
609 M: EnvelopedMessage,
610 {
611 let prev_handler = self.handlers.insert(
612 TypeId::of::<M>(),
613 Box::new(move |envelope, session| {
614 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
615 let received_at = envelope.received_at;
616 tracing::info!("message received");
617 let start_time = Instant::now();
618 let future = (handler)(*envelope, session);
619 async move {
620 let result = future.await;
621 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
622 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
623 let queue_duration_ms = total_duration_ms - processing_duration_ms;
624 let payload_type = M::NAME;
625
626 match result {
627 Err(error) => {
628 tracing::error!(
629 ?error,
630 total_duration_ms,
631 processing_duration_ms,
632 queue_duration_ms,
633 payload_type,
634 "error handling message"
635 )
636 }
637 Ok(()) => tracing::info!(
638 total_duration_ms,
639 processing_duration_ms,
640 queue_duration_ms,
641 "finished handling message"
642 ),
643 }
644 }
645 .boxed()
646 }),
647 );
648 if prev_handler.is_some() {
649 panic!("registered a handler for the same message twice");
650 }
651 self
652 }
653
654 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
655 where
656 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
657 Fut: 'static + Send + Future<Output = Result<()>>,
658 M: EnvelopedMessage,
659 {
660 self.add_handler(move |envelope, session| handler(envelope.payload, session));
661 self
662 }
663
664 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
665 where
666 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
667 Fut: Send + Future<Output = Result<()>>,
668 M: RequestMessage,
669 {
670 let handler = Arc::new(handler);
671 self.add_handler(move |envelope, session| {
672 let receipt = envelope.receipt();
673 let handler = handler.clone();
674 async move {
675 let peer = session.peer.clone();
676 let responded = Arc::new(AtomicBool::default());
677 let response = Response {
678 peer: peer.clone(),
679 responded: responded.clone(),
680 receipt,
681 };
682 match (handler)(envelope.payload, response, session).await {
683 Ok(()) => {
684 if responded.load(std::sync::atomic::Ordering::SeqCst) {
685 Ok(())
686 } else {
687 Err(anyhow!("handler did not send a response"))?
688 }
689 }
690 Err(error) => {
691 let proto_err = match &error {
692 Error::Internal(err) => err.to_proto(),
693 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
694 };
695 peer.respond_with_error(receipt, proto_err)?;
696 Err(error)
697 }
698 }
699 }
700 })
701 }
702
    /// Returns a future that drives one client connection until it closes.
    ///
    /// The future registers the connection with the peer, builds the connection's [`Session`],
    /// sends the initial client update, and then dispatches incoming messages to the registered
    /// handlers until the I/O task finishes, the incoming stream ends, or the server tears down.
    /// When the loop ends because of I/O completion or stream end, `connection_lost` is awaited
    /// to sign the connection out.
    ///
    /// `send_connection_id`, when provided, receives the connection id once the hello message
    /// has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields start out empty and are recorded as soon as their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client is created for every connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): this early return skips `connection_lost` although the connection
            // was already added to the peer above — confirm cleanup happens elsewhere.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before the next message is read, so no further message
                // is pulled from the stream while all permits are taken.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O, running handlers,
                // then the next incoming message.
                futures::select_biased! {
                    // NOTE(review): returning here also skips `connection_lost` — presumably
                    // acceptable because `teardown` resets the pool; confirm.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that are still pending are dropped before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
851
852 async fn send_initial_client_update(
853 &self,
854 connection_id: ConnectionId,
855 principal: &Principal,
856 zed_version: ZedVersion,
857 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
858 session: &Session,
859 ) -> Result<()> {
860 self.peer.send(
861 connection_id,
862 proto::Hello {
863 peer_id: Some(connection_id.into()),
864 },
865 )?;
866 tracing::info!("sent hello message");
867 if let Some(send_connection_id) = send_connection_id.take() {
868 let _ = send_connection_id.send(connection_id);
869 }
870
871 match principal {
872 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
873 if !user.connected_once {
874 self.peer.send(connection_id, proto::ShowContacts {})?;
875 self.app_state
876 .db
877 .set_user_connected_once(user.id, true)
878 .await?;
879 }
880
881 update_user_plan(user.id, session).await?;
882
883 let contacts = self.app_state.db.get_contacts(user.id).await?;
884
885 {
886 let mut pool = self.connection_pool.lock();
887 pool.add_connection(connection_id, user.id, user.admin, zed_version);
888 self.peer.send(
889 connection_id,
890 build_initial_contacts_update(contacts, &pool),
891 )?;
892 }
893
894 if should_auto_subscribe_to_channels(zed_version) {
895 subscribe_user_to_channels(user.id, session).await?;
896 }
897
898 if let Some(incoming_call) =
899 self.app_state.db.incoming_call_for_user(user.id).await?
900 {
901 self.peer.send(connection_id, incoming_call)?;
902 }
903
904 update_user_contacts(user.id, session).await?;
905 }
906 }
907
908 Ok(())
909 }
910
911 pub async fn invite_code_redeemed(
912 self: &Arc<Self>,
913 inviter_id: UserId,
914 invitee_id: UserId,
915 ) -> Result<()> {
916 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
917 if let Some(code) = &user.invite_code {
918 let pool = self.connection_pool.lock();
919 let invitee_contact = contact_for_user(invitee_id, false, &pool);
920 for connection_id in pool.user_connection_ids(inviter_id) {
921 self.peer.send(
922 connection_id,
923 proto::UpdateContacts {
924 contacts: vec![invitee_contact.clone()],
925 ..Default::default()
926 },
927 )?;
928 self.peer.send(
929 connection_id,
930 proto::UpdateInviteInfo {
931 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
932 count: user.invite_count as u32,
933 },
934 )?;
935 }
936 }
937 }
938 Ok(())
939 }
940
941 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
942 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
943 if let Some(invite_code) = &user.invite_code {
944 let pool = self.connection_pool.lock();
945 for connection_id in pool.user_connection_ids(user_id) {
946 self.peer.send(
947 connection_id,
948 proto::UpdateInviteInfo {
949 url: format!(
950 "{}{}",
951 self.app_state.config.invite_link_prefix, invite_code
952 ),
953 count: user.invite_count as u32,
954 },
955 )?;
956 }
957 }
958 }
959 Ok(())
960 }
961
962 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
963 let pool = self.connection_pool.lock();
964 for connection_id in pool.user_connection_ids(user_id) {
965 self.peer
966 .send(connection_id, proto::RefreshLlmToken {})
967 .trace_err();
968 }
969 }
970
971 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
972 ServerSnapshot {
973 connection_pool: ConnectionPoolGuard {
974 guard: self.connection_pool.lock(),
975 _not_send: PhantomData,
976 },
977 peer: &self.peer,
978 }
979 }
980}
981
982impl Deref for ConnectionPoolGuard<'_> {
983 type Target = ConnectionPool;
984
985 fn deref(&self) -> &Self::Target {
986 &self.guard
987 }
988}
989
990impl DerefMut for ConnectionPoolGuard<'_> {
991 fn deref_mut(&mut self) -> &mut Self::Target {
992 &mut self.guard
993 }
994}
995
996impl Drop for ConnectionPoolGuard<'_> {
997 fn drop(&mut self) {
998 #[cfg(test)]
999 self.check_invariants();
1000 }
1001}
1002
1003fn broadcast<F>(
1004 sender_id: Option<ConnectionId>,
1005 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1006 mut f: F,
1007) where
1008 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1009{
1010 for receiver_id in receiver_ids {
1011 if Some(receiver_id) != sender_id {
1012 if let Err(error) = f(receiver_id) {
1013 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1014 }
1015 }
1016 }
1017}
1018
1019pub struct ProtocolVersion(u32);
1020
1021impl Header for ProtocolVersion {
1022 fn name() -> &'static HeaderName {
1023 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1024 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1025 }
1026
1027 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1028 where
1029 Self: Sized,
1030 I: Iterator<Item = &'i axum::http::HeaderValue>,
1031 {
1032 let version = values
1033 .next()
1034 .ok_or_else(axum::headers::Error::invalid)?
1035 .to_str()
1036 .map_err(|_| axum::headers::Error::invalid())?
1037 .parse()
1038 .map_err(|_| axum::headers::Error::invalid())?;
1039 Ok(Self(version))
1040 }
1041
1042 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1043 values.extend([self.0.to_string().parse().unwrap()]);
1044 }
1045}
1046
1047pub struct AppVersionHeader(SemanticVersion);
1048impl Header for AppVersionHeader {
1049 fn name() -> &'static HeaderName {
1050 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1051 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1052 }
1053
1054 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1055 where
1056 Self: Sized,
1057 I: Iterator<Item = &'i axum::http::HeaderValue>,
1058 {
1059 let version = values
1060 .next()
1061 .ok_or_else(axum::headers::Error::invalid)?
1062 .to_str()
1063 .map_err(|_| axum::headers::Error::invalid())?
1064 .parse()
1065 .map_err(|_| axum::headers::Error::invalid())?;
1066 Ok(Self(version))
1067 }
1068
1069 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1070 values.extend([self.0.to_string().parse().unwrap()]);
1071 }
1072}
1073
1074pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1075 Router::new()
1076 .route("/rpc", get(handle_websocket_request))
1077 .layer(
1078 ServiceBuilder::new()
1079 .layer(Extension(server.app_state.clone()))
1080 .layer(middleware::from_fn(auth::validate_header)),
1081 )
1082 .route("/metrics", get(handle_metrics))
1083 .layer(Extension(server))
1084}
1085
1086pub async fn handle_websocket_request(
1087 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1088 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1089 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1090 Extension(server): Extension<Arc<Server>>,
1091 Extension(principal): Extension<Principal>,
1092 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1093 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1094 ws: WebSocketUpgrade,
1095) -> axum::response::Response {
1096 if protocol_version != rpc::PROTOCOL_VERSION {
1097 return (
1098 StatusCode::UPGRADE_REQUIRED,
1099 "client must be upgraded".to_string(),
1100 )
1101 .into_response();
1102 }
1103
1104 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1105 return (
1106 StatusCode::UPGRADE_REQUIRED,
1107 "no version header found".to_string(),
1108 )
1109 .into_response();
1110 };
1111
1112 if !version.can_collaborate() {
1113 return (
1114 StatusCode::UPGRADE_REQUIRED,
1115 "client must be upgraded".to_string(),
1116 )
1117 .into_response();
1118 }
1119
1120 let socket_address = socket_address.to_string();
1121 ws.on_upgrade(move |socket| {
1122 let socket = socket
1123 .map_ok(to_tungstenite_message)
1124 .err_into()
1125 .with(|message| async move { to_axum_message(message) });
1126 let connection = Connection::new(Box::pin(socket));
1127 async move {
1128 server
1129 .handle_connection(
1130 connection,
1131 socket_address,
1132 principal,
1133 version,
1134 country_code_header.map(|header| header.to_string()),
1135 system_id_header.map(|header| header.to_string()),
1136 None,
1137 Executor::Production,
1138 )
1139 .await;
1140 }
1141 })
1142}
1143
1144pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1145 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1146 let connections_metric = CONNECTIONS_METRIC
1147 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1148
1149 let connections = server
1150 .connection_pool
1151 .lock()
1152 .connections()
1153 .filter(|connection| !connection.admin)
1154 .count();
1155 connections_metric.set(connections as _);
1156
1157 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1158 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1159 register_int_gauge!(
1160 "shared_projects",
1161 "number of open projects with one or more guests"
1162 )
1163 .unwrap()
1164 });
1165
1166 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1167 shared_projects_metric.set(shared_projects as _);
1168
1169 let encoder = prometheus::TextEncoder::new();
1170 let metric_families = prometheus::gather();
1171 let encoded_metrics = encoder
1172 .encode_to_string(&metric_families)
1173 .map_err(|err| anyhow!("{}", err))?;
1174 Ok(encoded_metrics)
1175}
1176
1177#[instrument(err, skip(executor))]
1178async fn connection_lost(
1179 session: Session,
1180 mut teardown: watch::Receiver<bool>,
1181 executor: Executor,
1182) -> Result<()> {
1183 session.peer.disconnect(session.connection_id);
1184 session
1185 .connection_pool()
1186 .await
1187 .remove_connection(session.connection_id)?;
1188
1189 session
1190 .db()
1191 .await
1192 .connection_lost(session.connection_id)
1193 .await
1194 .trace_err();
1195
1196 futures::select_biased! {
1197 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1198
1199 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1200 leave_room_for_session(&session, session.connection_id).await.trace_err();
1201 leave_channel_buffers_for_session(&session)
1202 .await
1203 .trace_err();
1204
1205 if !session
1206 .connection_pool()
1207 .await
1208 .is_user_online(session.user_id())
1209 {
1210 let db = session.db().await;
1211 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1212 room_updated(&room, &session.peer);
1213 }
1214 }
1215
1216 update_user_contacts(session.user_id(), &session).await?;
1217 },
1218 _ = teardown.changed().fuse() => {}
1219 }
1220
1221 Ok(())
1222}
1223
1224/// Acknowledges a ping from a client, used to keep the connection alive.
1225async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1226 response.send(proto::Ack {})?;
1227 Ok(())
1228}
1229
1230/// Creates a new room for calling (outside of channels)
1231async fn create_room(
1232 _request: proto::CreateRoom,
1233 response: Response<proto::CreateRoom>,
1234 session: Session,
1235) -> Result<()> {
1236 let livekit_room = nanoid::nanoid!(30);
1237
1238 let live_kit_connection_info = util::maybe!(async {
1239 let live_kit = session.app_state.livekit_client.as_ref();
1240 let live_kit = live_kit?;
1241 let user_id = session.user_id().to_string();
1242
1243 let token = live_kit
1244 .room_token(&livekit_room, &user_id.to_string())
1245 .trace_err()?;
1246
1247 Some(proto::LiveKitConnectionInfo {
1248 server_url: live_kit.url().into(),
1249 token,
1250 can_publish: true,
1251 })
1252 })
1253 .await;
1254
1255 let room = session
1256 .db()
1257 .await
1258 .create_room(session.user_id(), session.connection_id, &livekit_room)
1259 .await?;
1260
1261 response.send(proto::CreateRoomResponse {
1262 room: Some(room.clone()),
1263 live_kit_connection_info,
1264 })?;
1265
1266 update_user_contacts(session.user_id(), &session).await?;
1267 Ok(())
1268}
1269
1270/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1271async fn join_room(
1272 request: proto::JoinRoom,
1273 response: Response<proto::JoinRoom>,
1274 session: Session,
1275) -> Result<()> {
1276 let room_id = RoomId::from_proto(request.id);
1277
1278 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1279
1280 if let Some(channel_id) = channel_id {
1281 return join_channel_internal(channel_id, Box::new(response), session).await;
1282 }
1283
1284 let joined_room = {
1285 let room = session
1286 .db()
1287 .await
1288 .join_room(room_id, session.user_id(), session.connection_id)
1289 .await?;
1290 room_updated(&room.room, &session.peer);
1291 room.into_inner()
1292 };
1293
1294 for connection_id in session
1295 .connection_pool()
1296 .await
1297 .user_connection_ids(session.user_id())
1298 {
1299 session
1300 .peer
1301 .send(
1302 connection_id,
1303 proto::CallCanceled {
1304 room_id: room_id.to_proto(),
1305 },
1306 )
1307 .trace_err();
1308 }
1309
1310 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1311 {
1312 live_kit
1313 .room_token(
1314 &joined_room.room.livekit_room,
1315 &session.user_id().to_string(),
1316 )
1317 .trace_err()
1318 .map(|token| proto::LiveKitConnectionInfo {
1319 server_url: live_kit.url().into(),
1320 token,
1321 can_publish: true,
1322 })
1323 } else {
1324 None
1325 };
1326
1327 response.send(proto::JoinRoomResponse {
1328 room: Some(joined_room.room),
1329 channel_id: None,
1330 live_kit_connection_info,
1331 })?;
1332
1333 update_user_contacts(session.user_id(), &session).await?;
1334 Ok(())
1335}
1336
1337/// Rejoin room is used to reconnect to a room after connection errors.
1338async fn rejoin_room(
1339 request: proto::RejoinRoom,
1340 response: Response<proto::RejoinRoom>,
1341 session: Session,
1342) -> Result<()> {
1343 let room;
1344 let channel;
1345 {
1346 let mut rejoined_room = session
1347 .db()
1348 .await
1349 .rejoin_room(request, session.user_id(), session.connection_id)
1350 .await?;
1351
1352 response.send(proto::RejoinRoomResponse {
1353 room: Some(rejoined_room.room.clone()),
1354 reshared_projects: rejoined_room
1355 .reshared_projects
1356 .iter()
1357 .map(|project| proto::ResharedProject {
1358 id: project.id.to_proto(),
1359 collaborators: project
1360 .collaborators
1361 .iter()
1362 .map(|collaborator| collaborator.to_proto())
1363 .collect(),
1364 })
1365 .collect(),
1366 rejoined_projects: rejoined_room
1367 .rejoined_projects
1368 .iter()
1369 .map(|rejoined_project| rejoined_project.to_proto())
1370 .collect(),
1371 })?;
1372 room_updated(&rejoined_room.room, &session.peer);
1373
1374 for project in &rejoined_room.reshared_projects {
1375 for collaborator in &project.collaborators {
1376 session
1377 .peer
1378 .send(
1379 collaborator.connection_id,
1380 proto::UpdateProjectCollaborator {
1381 project_id: project.id.to_proto(),
1382 old_peer_id: Some(project.old_connection_id.into()),
1383 new_peer_id: Some(session.connection_id.into()),
1384 },
1385 )
1386 .trace_err();
1387 }
1388
1389 broadcast(
1390 Some(session.connection_id),
1391 project
1392 .collaborators
1393 .iter()
1394 .map(|collaborator| collaborator.connection_id),
1395 |connection_id| {
1396 session.peer.forward_send(
1397 session.connection_id,
1398 connection_id,
1399 proto::UpdateProject {
1400 project_id: project.id.to_proto(),
1401 worktrees: project.worktrees.clone(),
1402 },
1403 )
1404 },
1405 );
1406 }
1407
1408 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1409
1410 let rejoined_room = rejoined_room.into_inner();
1411
1412 room = rejoined_room.room;
1413 channel = rejoined_room.channel;
1414 }
1415
1416 if let Some(channel) = channel {
1417 channel_updated(
1418 &channel,
1419 &room,
1420 &session.peer,
1421 &*session.connection_pool().await,
1422 );
1423 }
1424
1425 update_user_contacts(session.user_id(), &session).await?;
1426 Ok(())
1427}
1428
1429fn notify_rejoined_projects(
1430 rejoined_projects: &mut Vec<RejoinedProject>,
1431 session: &Session,
1432) -> Result<()> {
1433 for project in rejoined_projects.iter() {
1434 for collaborator in &project.collaborators {
1435 session
1436 .peer
1437 .send(
1438 collaborator.connection_id,
1439 proto::UpdateProjectCollaborator {
1440 project_id: project.id.to_proto(),
1441 old_peer_id: Some(project.old_connection_id.into()),
1442 new_peer_id: Some(session.connection_id.into()),
1443 },
1444 )
1445 .trace_err();
1446 }
1447 }
1448
1449 for project in rejoined_projects {
1450 for worktree in mem::take(&mut project.worktrees) {
1451 // Stream this worktree's entries.
1452 let message = proto::UpdateWorktree {
1453 project_id: project.id.to_proto(),
1454 worktree_id: worktree.id,
1455 abs_path: worktree.abs_path.clone(),
1456 root_name: worktree.root_name,
1457 updated_entries: worktree.updated_entries,
1458 removed_entries: worktree.removed_entries,
1459 scan_id: worktree.scan_id,
1460 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1461 updated_repositories: worktree.updated_repositories,
1462 removed_repositories: worktree.removed_repositories,
1463 };
1464 for update in proto::split_worktree_update(message) {
1465 session.peer.send(session.connection_id, update.clone())?;
1466 }
1467
1468 // Stream this worktree's diagnostics.
1469 for summary in worktree.diagnostic_summaries {
1470 session.peer.send(
1471 session.connection_id,
1472 proto::UpdateDiagnosticSummary {
1473 project_id: project.id.to_proto(),
1474 worktree_id: worktree.id,
1475 summary: Some(summary),
1476 },
1477 )?;
1478 }
1479
1480 for settings_file in worktree.settings_files {
1481 session.peer.send(
1482 session.connection_id,
1483 proto::UpdateWorktreeSettings {
1484 project_id: project.id.to_proto(),
1485 worktree_id: worktree.id,
1486 path: settings_file.path,
1487 content: Some(settings_file.content),
1488 kind: Some(settings_file.kind.to_proto().into()),
1489 },
1490 )?;
1491 }
1492 }
1493
1494 for language_server in &project.language_servers {
1495 session.peer.send(
1496 session.connection_id,
1497 proto::UpdateLanguageServer {
1498 project_id: project.id.to_proto(),
1499 language_server_id: language_server.id,
1500 variant: Some(
1501 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1502 proto::LspDiskBasedDiagnosticsUpdated {},
1503 ),
1504 ),
1505 },
1506 )?;
1507 }
1508 }
1509 Ok(())
1510}
1511
1512/// leave room disconnects from the room.
1513async fn leave_room(
1514 _: proto::LeaveRoom,
1515 response: Response<proto::LeaveRoom>,
1516 session: Session,
1517) -> Result<()> {
1518 leave_room_for_session(&session, session.connection_id).await?;
1519 response.send(proto::Ack {})?;
1520 Ok(())
1521}
1522
1523/// Updates the permissions of someone else in the room.
1524async fn set_room_participant_role(
1525 request: proto::SetRoomParticipantRole,
1526 response: Response<proto::SetRoomParticipantRole>,
1527 session: Session,
1528) -> Result<()> {
1529 let user_id = UserId::from_proto(request.user_id);
1530 let role = ChannelRole::from(request.role());
1531
1532 let (livekit_room, can_publish) = {
1533 let room = session
1534 .db()
1535 .await
1536 .set_room_participant_role(
1537 session.user_id(),
1538 RoomId::from_proto(request.room_id),
1539 user_id,
1540 role,
1541 )
1542 .await?;
1543
1544 let livekit_room = room.livekit_room.clone();
1545 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1546 room_updated(&room, &session.peer);
1547 (livekit_room, can_publish)
1548 };
1549
1550 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1551 live_kit
1552 .update_participant(
1553 livekit_room.clone(),
1554 request.user_id.to_string(),
1555 livekit_api::proto::ParticipantPermission {
1556 can_subscribe: true,
1557 can_publish,
1558 can_publish_data: can_publish,
1559 hidden: false,
1560 recorder: false,
1561 },
1562 )
1563 .await
1564 .trace_err();
1565 }
1566
1567 response.send(proto::Ack {})?;
1568 Ok(())
1569}
1570
1571/// Call someone else into the current room
1572async fn call(
1573 request: proto::Call,
1574 response: Response<proto::Call>,
1575 session: Session,
1576) -> Result<()> {
1577 let room_id = RoomId::from_proto(request.room_id);
1578 let calling_user_id = session.user_id();
1579 let calling_connection_id = session.connection_id;
1580 let called_user_id = UserId::from_proto(request.called_user_id);
1581 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1582 if !session
1583 .db()
1584 .await
1585 .has_contact(calling_user_id, called_user_id)
1586 .await?
1587 {
1588 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1589 }
1590
1591 let incoming_call = {
1592 let (room, incoming_call) = &mut *session
1593 .db()
1594 .await
1595 .call(
1596 room_id,
1597 calling_user_id,
1598 calling_connection_id,
1599 called_user_id,
1600 initial_project_id,
1601 )
1602 .await?;
1603 room_updated(room, &session.peer);
1604 mem::take(incoming_call)
1605 };
1606 update_user_contacts(called_user_id, &session).await?;
1607
1608 let mut calls = session
1609 .connection_pool()
1610 .await
1611 .user_connection_ids(called_user_id)
1612 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1613 .collect::<FuturesUnordered<_>>();
1614
1615 while let Some(call_response) = calls.next().await {
1616 match call_response.as_ref() {
1617 Ok(_) => {
1618 response.send(proto::Ack {})?;
1619 return Ok(());
1620 }
1621 Err(_) => {
1622 call_response.trace_err();
1623 }
1624 }
1625 }
1626
1627 {
1628 let room = session
1629 .db()
1630 .await
1631 .call_failed(room_id, called_user_id)
1632 .await?;
1633 room_updated(&room, &session.peer);
1634 }
1635 update_user_contacts(called_user_id, &session).await?;
1636
1637 Err(anyhow!("failed to ring user"))?
1638}
1639
1640/// Cancel an outgoing call.
1641async fn cancel_call(
1642 request: proto::CancelCall,
1643 response: Response<proto::CancelCall>,
1644 session: Session,
1645) -> Result<()> {
1646 let called_user_id = UserId::from_proto(request.called_user_id);
1647 let room_id = RoomId::from_proto(request.room_id);
1648 {
1649 let room = session
1650 .db()
1651 .await
1652 .cancel_call(room_id, session.connection_id, called_user_id)
1653 .await?;
1654 room_updated(&room, &session.peer);
1655 }
1656
1657 for connection_id in session
1658 .connection_pool()
1659 .await
1660 .user_connection_ids(called_user_id)
1661 {
1662 session
1663 .peer
1664 .send(
1665 connection_id,
1666 proto::CallCanceled {
1667 room_id: room_id.to_proto(),
1668 },
1669 )
1670 .trace_err();
1671 }
1672 response.send(proto::Ack {})?;
1673
1674 update_user_contacts(called_user_id, &session).await?;
1675 Ok(())
1676}
1677
1678/// Decline an incoming call.
1679async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1680 let room_id = RoomId::from_proto(message.room_id);
1681 {
1682 let room = session
1683 .db()
1684 .await
1685 .decline_call(Some(room_id), session.user_id())
1686 .await?
1687 .ok_or_else(|| anyhow!("failed to decline call"))?;
1688 room_updated(&room, &session.peer);
1689 }
1690
1691 for connection_id in session
1692 .connection_pool()
1693 .await
1694 .user_connection_ids(session.user_id())
1695 {
1696 session
1697 .peer
1698 .send(
1699 connection_id,
1700 proto::CallCanceled {
1701 room_id: room_id.to_proto(),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 update_user_contacts(session.user_id(), &session).await?;
1707 Ok(())
1708}
1709
1710/// Updates other participants in the room with your current location.
1711async fn update_participant_location(
1712 request: proto::UpdateParticipantLocation,
1713 response: Response<proto::UpdateParticipantLocation>,
1714 session: Session,
1715) -> Result<()> {
1716 let room_id = RoomId::from_proto(request.room_id);
1717 let location = request
1718 .location
1719 .ok_or_else(|| anyhow!("invalid location"))?;
1720
1721 let db = session.db().await;
1722 let room = db
1723 .update_room_participant_location(room_id, session.connection_id, location)
1724 .await?;
1725
1726 room_updated(&room, &session.peer);
1727 response.send(proto::Ack {})?;
1728 Ok(())
1729}
1730
1731/// Share a project into the room.
1732async fn share_project(
1733 request: proto::ShareProject,
1734 response: Response<proto::ShareProject>,
1735 session: Session,
1736) -> Result<()> {
1737 let (project_id, room) = &*session
1738 .db()
1739 .await
1740 .share_project(
1741 RoomId::from_proto(request.room_id),
1742 session.connection_id,
1743 &request.worktrees,
1744 request.is_ssh_project,
1745 )
1746 .await?;
1747 response.send(proto::ShareProjectResponse {
1748 project_id: project_id.to_proto(),
1749 })?;
1750 room_updated(room, &session.peer);
1751
1752 Ok(())
1753}
1754
1755/// Unshare a project from the room.
1756async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1757 let project_id = ProjectId::from_proto(message.project_id);
1758 unshare_project_internal(project_id, session.connection_id, &session).await
1759}
1760
1761async fn unshare_project_internal(
1762 project_id: ProjectId,
1763 connection_id: ConnectionId,
1764 session: &Session,
1765) -> Result<()> {
1766 let delete = {
1767 let room_guard = session
1768 .db()
1769 .await
1770 .unshare_project(project_id, connection_id)
1771 .await?;
1772
1773 let (delete, room, guest_connection_ids) = &*room_guard;
1774
1775 let message = proto::UnshareProject {
1776 project_id: project_id.to_proto(),
1777 };
1778
1779 broadcast(
1780 Some(connection_id),
1781 guest_connection_ids.iter().copied(),
1782 |conn_id| session.peer.send(conn_id, message.clone()),
1783 );
1784 if let Some(room) = room {
1785 room_updated(room, &session.peer);
1786 }
1787
1788 *delete
1789 };
1790
1791 if delete {
1792 let db = session.db().await;
1793 db.delete_project(project_id).await?;
1794 }
1795
1796 Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801 request: proto::JoinProject,
1802 response: Response<proto::JoinProject>,
1803 session: Session,
1804) -> Result<()> {
1805 let project_id = ProjectId::from_proto(request.project_id);
1806
1807 tracing::info!(%project_id, "join project");
1808
1809 let db = session.db().await;
1810 let (project, replica_id) = &mut *db
1811 .join_project(project_id, session.connection_id, session.user_id())
1812 .await?;
1813 drop(db);
1814 tracing::info!(%project_id, "join remote project");
1815 join_project_internal(response, session, project, replica_id)
1816}
1817
1818trait JoinProjectInternalResponse {
1819 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1820}
1821impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1822 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1823 Response::<proto::JoinProject>::send(self, result)
1824 }
1825}
1826
1827fn join_project_internal(
1828 response: impl JoinProjectInternalResponse,
1829 session: Session,
1830 project: &mut Project,
1831 replica_id: &ReplicaId,
1832) -> Result<()> {
1833 let collaborators = project
1834 .collaborators
1835 .iter()
1836 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1837 .map(|collaborator| collaborator.to_proto())
1838 .collect::<Vec<_>>();
1839 let project_id = project.id;
1840 let guest_user_id = session.user_id();
1841
1842 let worktrees = project
1843 .worktrees
1844 .iter()
1845 .map(|(id, worktree)| proto::WorktreeMetadata {
1846 id: *id,
1847 root_name: worktree.root_name.clone(),
1848 visible: worktree.visible,
1849 abs_path: worktree.abs_path.clone(),
1850 })
1851 .collect::<Vec<_>>();
1852
1853 let add_project_collaborator = proto::AddProjectCollaborator {
1854 project_id: project_id.to_proto(),
1855 collaborator: Some(proto::Collaborator {
1856 peer_id: Some(session.connection_id.into()),
1857 replica_id: replica_id.0 as u32,
1858 user_id: guest_user_id.to_proto(),
1859 is_host: false,
1860 }),
1861 };
1862
1863 for collaborator in &collaborators {
1864 session
1865 .peer
1866 .send(
1867 collaborator.peer_id.unwrap().into(),
1868 add_project_collaborator.clone(),
1869 )
1870 .trace_err();
1871 }
1872
1873 // First, we send the metadata associated with each worktree.
1874 response.send(proto::JoinProjectResponse {
1875 project_id: project.id.0 as u64,
1876 worktrees: worktrees.clone(),
1877 replica_id: replica_id.0 as u32,
1878 collaborators: collaborators.clone(),
1879 language_servers: project.language_servers.clone(),
1880 role: project.role.into(),
1881 })?;
1882
1883 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1884 // Stream this worktree's entries.
1885 let message = proto::UpdateWorktree {
1886 project_id: project_id.to_proto(),
1887 worktree_id,
1888 abs_path: worktree.abs_path.clone(),
1889 root_name: worktree.root_name,
1890 updated_entries: worktree.entries,
1891 removed_entries: Default::default(),
1892 scan_id: worktree.scan_id,
1893 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1894 updated_repositories: worktree.repository_entries.into_values().collect(),
1895 removed_repositories: Default::default(),
1896 };
1897 for update in proto::split_worktree_update(message) {
1898 session.peer.send(session.connection_id, update.clone())?;
1899 }
1900
1901 // Stream this worktree's diagnostics.
1902 for summary in worktree.diagnostic_summaries {
1903 session.peer.send(
1904 session.connection_id,
1905 proto::UpdateDiagnosticSummary {
1906 project_id: project_id.to_proto(),
1907 worktree_id: worktree.id,
1908 summary: Some(summary),
1909 },
1910 )?;
1911 }
1912
1913 for settings_file in worktree.settings_files {
1914 session.peer.send(
1915 session.connection_id,
1916 proto::UpdateWorktreeSettings {
1917 project_id: project_id.to_proto(),
1918 worktree_id: worktree.id,
1919 path: settings_file.path,
1920 content: Some(settings_file.content),
1921 kind: Some(settings_file.kind.to_proto() as i32),
1922 },
1923 )?;
1924 }
1925 }
1926
1927 for language_server in &project.language_servers {
1928 session.peer.send(
1929 session.connection_id,
1930 proto::UpdateLanguageServer {
1931 project_id: project_id.to_proto(),
1932 language_server_id: language_server.id,
1933 variant: Some(
1934 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1935 proto::LspDiskBasedDiagnosticsUpdated {},
1936 ),
1937 ),
1938 },
1939 )?;
1940 }
1941
1942 Ok(())
1943}
1944
1945/// Leave someone elses shared project.
1946async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1947 let sender_id = session.connection_id;
1948 let project_id = ProjectId::from_proto(request.project_id);
1949 let db = session.db().await;
1950
1951 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1952 tracing::info!(
1953 %project_id,
1954 "leave project"
1955 );
1956
1957 project_left(project, &session);
1958 if let Some(room) = room {
1959 room_updated(room, &session.peer);
1960 }
1961
1962 Ok(())
1963}
1964
1965/// Updates other participants with changes to the project
1966async fn update_project(
1967 request: proto::UpdateProject,
1968 response: Response<proto::UpdateProject>,
1969 session: Session,
1970) -> Result<()> {
1971 let project_id = ProjectId::from_proto(request.project_id);
1972 let (room, guest_connection_ids) = &*session
1973 .db()
1974 .await
1975 .update_project(project_id, session.connection_id, &request.worktrees)
1976 .await?;
1977 broadcast(
1978 Some(session.connection_id),
1979 guest_connection_ids.iter().copied(),
1980 |connection_id| {
1981 session
1982 .peer
1983 .forward_send(session.connection_id, connection_id, request.clone())
1984 },
1985 );
1986 if let Some(room) = room {
1987 room_updated(room, &session.peer);
1988 }
1989 response.send(proto::Ack {})?;
1990
1991 Ok(())
1992}
1993
1994/// Updates other participants with changes to the worktree
1995async fn update_worktree(
1996 request: proto::UpdateWorktree,
1997 response: Response<proto::UpdateWorktree>,
1998 session: Session,
1999) -> Result<()> {
2000 let guest_connection_ids = session
2001 .db()
2002 .await
2003 .update_worktree(&request, session.connection_id)
2004 .await?;
2005
2006 broadcast(
2007 Some(session.connection_id),
2008 guest_connection_ids.iter().copied(),
2009 |connection_id| {
2010 session
2011 .peer
2012 .forward_send(session.connection_id, connection_id, request.clone())
2013 },
2014 );
2015 response.send(proto::Ack {})?;
2016 Ok(())
2017}
2018
2019/// Updates other participants with changes to the diagnostics
2020async fn update_diagnostic_summary(
2021 message: proto::UpdateDiagnosticSummary,
2022 session: Session,
2023) -> Result<()> {
2024 let guest_connection_ids = session
2025 .db()
2026 .await
2027 .update_diagnostic_summary(&message, session.connection_id)
2028 .await?;
2029
2030 broadcast(
2031 Some(session.connection_id),
2032 guest_connection_ids.iter().copied(),
2033 |connection_id| {
2034 session
2035 .peer
2036 .forward_send(session.connection_id, connection_id, message.clone())
2037 },
2038 );
2039
2040 Ok(())
2041}
2042
2043/// Updates other participants with changes to the worktree settings
2044async fn update_worktree_settings(
2045 message: proto::UpdateWorktreeSettings,
2046 session: Session,
2047) -> Result<()> {
2048 let guest_connection_ids = session
2049 .db()
2050 .await
2051 .update_worktree_settings(&message, session.connection_id)
2052 .await?;
2053
2054 broadcast(
2055 Some(session.connection_id),
2056 guest_connection_ids.iter().copied(),
2057 |connection_id| {
2058 session
2059 .peer
2060 .forward_send(session.connection_id, connection_id, message.clone())
2061 },
2062 );
2063
2064 Ok(())
2065}
2066
2067/// Notify other participants that a language server has started.
2068async fn start_language_server(
2069 request: proto::StartLanguageServer,
2070 session: Session,
2071) -> Result<()> {
2072 let guest_connection_ids = session
2073 .db()
2074 .await
2075 .start_language_server(&request, session.connection_id)
2076 .await?;
2077
2078 broadcast(
2079 Some(session.connection_id),
2080 guest_connection_ids.iter().copied(),
2081 |connection_id| {
2082 session
2083 .peer
2084 .forward_send(session.connection_id, connection_id, request.clone())
2085 },
2086 );
2087 Ok(())
2088}
2089
2090/// Notify other participants that a language server has changed.
2091async fn update_language_server(
2092 request: proto::UpdateLanguageServer,
2093 session: Session,
2094) -> Result<()> {
2095 let project_id = ProjectId::from_proto(request.project_id);
2096 let project_connection_ids = session
2097 .db()
2098 .await
2099 .project_connection_ids(project_id, session.connection_id, true)
2100 .await?;
2101 broadcast(
2102 Some(session.connection_id),
2103 project_connection_ids.iter().copied(),
2104 |connection_id| {
2105 session
2106 .peer
2107 .forward_send(session.connection_id, connection_id, request.clone())
2108 },
2109 );
2110 Ok(())
2111}
2112
2113/// forward a project request to the host. These requests should be read only
2114/// as guests are allowed to send them.
2115async fn forward_read_only_project_request<T>(
2116 request: T,
2117 response: Response<T>,
2118 session: Session,
2119) -> Result<()>
2120where
2121 T: EntityMessage + RequestMessage,
2122{
2123 let project_id = ProjectId::from_proto(request.remote_entity_id());
2124 let host_connection_id = session
2125 .db()
2126 .await
2127 .host_for_read_only_project_request(project_id, session.connection_id)
2128 .await?;
2129 let payload = session
2130 .peer
2131 .forward_request(session.connection_id, host_connection_id, request)
2132 .await?;
2133 response.send(payload)?;
2134 Ok(())
2135}
2136
2137async fn forward_find_search_candidates_request(
2138 request: proto::FindSearchCandidates,
2139 response: Response<proto::FindSearchCandidates>,
2140 session: Session,
2141) -> Result<()> {
2142 let project_id = ProjectId::from_proto(request.remote_entity_id());
2143 let host_connection_id = session
2144 .db()
2145 .await
2146 .host_for_read_only_project_request(project_id, session.connection_id)
2147 .await?;
2148 let payload = session
2149 .peer
2150 .forward_request(session.connection_id, host_connection_id, request)
2151 .await?;
2152 response.send(payload)?;
2153 Ok(())
2154}
2155
2156/// forward a project request to the host. These requests are disallowed
2157/// for guests.
2158async fn forward_mutating_project_request<T>(
2159 request: T,
2160 response: Response<T>,
2161 session: Session,
2162) -> Result<()>
2163where
2164 T: EntityMessage + RequestMessage,
2165{
2166 let project_id = ProjectId::from_proto(request.remote_entity_id());
2167
2168 let host_connection_id = session
2169 .db()
2170 .await
2171 .host_for_mutating_project_request(project_id, session.connection_id)
2172 .await?;
2173 let payload = session
2174 .peer
2175 .forward_request(session.connection_id, host_connection_id, request)
2176 .await?;
2177 response.send(payload)?;
2178 Ok(())
2179}
2180
2181/// Notify other participants that a new buffer has been created
2182async fn create_buffer_for_peer(
2183 request: proto::CreateBufferForPeer,
2184 session: Session,
2185) -> Result<()> {
2186 session
2187 .db()
2188 .await
2189 .check_user_is_project_host(
2190 ProjectId::from_proto(request.project_id),
2191 session.connection_id,
2192 )
2193 .await?;
2194 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2195 session
2196 .peer
2197 .forward_send(session.connection_id, peer_id.into(), request)?;
2198 Ok(())
2199}
2200
2201/// Notify other participants that a buffer has been updated. This is
2202/// allowed for guests as long as the update is limited to selections.
2203async fn update_buffer(
2204 request: proto::UpdateBuffer,
2205 response: Response<proto::UpdateBuffer>,
2206 session: Session,
2207) -> Result<()> {
2208 let project_id = ProjectId::from_proto(request.project_id);
2209 let mut capability = Capability::ReadOnly;
2210
2211 for op in request.operations.iter() {
2212 match op.variant {
2213 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2214 Some(_) => capability = Capability::ReadWrite,
2215 }
2216 }
2217
2218 let host = {
2219 let guard = session
2220 .db()
2221 .await
2222 .connections_for_buffer_update(project_id, session.connection_id, capability)
2223 .await?;
2224
2225 let (host, guests) = &*guard;
2226
2227 broadcast(
2228 Some(session.connection_id),
2229 guests.clone(),
2230 |connection_id| {
2231 session
2232 .peer
2233 .forward_send(session.connection_id, connection_id, request.clone())
2234 },
2235 );
2236
2237 *host
2238 };
2239
2240 if host != session.connection_id {
2241 session
2242 .peer
2243 .forward_request(session.connection_id, host, request.clone())
2244 .await?;
2245 }
2246
2247 response.send(proto::Ack {})?;
2248 Ok(())
2249}
2250
2251async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2252 let project_id = ProjectId::from_proto(message.project_id);
2253
2254 let operation = message.operation.as_ref().context("invalid operation")?;
2255 let capability = match operation.variant.as_ref() {
2256 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2257 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2258 match buffer_op.variant {
2259 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2260 Capability::ReadOnly
2261 }
2262 _ => Capability::ReadWrite,
2263 }
2264 } else {
2265 Capability::ReadWrite
2266 }
2267 }
2268 Some(_) => Capability::ReadWrite,
2269 None => Capability::ReadOnly,
2270 };
2271
2272 let guard = session
2273 .db()
2274 .await
2275 .connections_for_buffer_update(project_id, session.connection_id, capability)
2276 .await?;
2277
2278 let (host, guests) = &*guard;
2279
2280 broadcast(
2281 Some(session.connection_id),
2282 guests.iter().chain([host]).copied(),
2283 |connection_id| {
2284 session
2285 .peer
2286 .forward_send(session.connection_id, connection_id, message.clone())
2287 },
2288 );
2289
2290 Ok(())
2291}
2292
2293/// Notify other participants that a project has been updated.
2294async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2295 request: T,
2296 session: Session,
2297) -> Result<()> {
2298 let project_id = ProjectId::from_proto(request.remote_entity_id());
2299 let project_connection_ids = session
2300 .db()
2301 .await
2302 .project_connection_ids(project_id, session.connection_id, false)
2303 .await?;
2304
2305 broadcast(
2306 Some(session.connection_id),
2307 project_connection_ids.iter().copied(),
2308 |connection_id| {
2309 session
2310 .peer
2311 .forward_send(session.connection_id, connection_id, request.clone())
2312 },
2313 );
2314 Ok(())
2315}
2316
2317/// Start following another user in a call.
2318async fn follow(
2319 request: proto::Follow,
2320 response: Response<proto::Follow>,
2321 session: Session,
2322) -> Result<()> {
2323 let room_id = RoomId::from_proto(request.room_id);
2324 let project_id = request.project_id.map(ProjectId::from_proto);
2325 let leader_id = request
2326 .leader_id
2327 .ok_or_else(|| anyhow!("invalid leader id"))?
2328 .into();
2329 let follower_id = session.connection_id;
2330
2331 session
2332 .db()
2333 .await
2334 .check_room_participants(room_id, leader_id, session.connection_id)
2335 .await?;
2336
2337 let response_payload = session
2338 .peer
2339 .forward_request(session.connection_id, leader_id, request)
2340 .await?;
2341 response.send(response_payload)?;
2342
2343 if let Some(project_id) = project_id {
2344 let room = session
2345 .db()
2346 .await
2347 .follow(room_id, project_id, leader_id, follower_id)
2348 .await?;
2349 room_updated(&room, &session.peer);
2350 }
2351
2352 Ok(())
2353}
2354
2355/// Stop following another user in a call.
2356async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2357 let room_id = RoomId::from_proto(request.room_id);
2358 let project_id = request.project_id.map(ProjectId::from_proto);
2359 let leader_id = request
2360 .leader_id
2361 .ok_or_else(|| anyhow!("invalid leader id"))?
2362 .into();
2363 let follower_id = session.connection_id;
2364
2365 session
2366 .db()
2367 .await
2368 .check_room_participants(room_id, leader_id, session.connection_id)
2369 .await?;
2370
2371 session
2372 .peer
2373 .forward_send(session.connection_id, leader_id, request)?;
2374
2375 if let Some(project_id) = project_id {
2376 let room = session
2377 .db()
2378 .await
2379 .unfollow(room_id, project_id, leader_id, follower_id)
2380 .await?;
2381 room_updated(&room, &session.peer);
2382 }
2383
2384 Ok(())
2385}
2386
2387/// Notify everyone following you of your current location.
2388async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2389 let room_id = RoomId::from_proto(request.room_id);
2390 let database = session.db.lock().await;
2391
2392 let connection_ids = if let Some(project_id) = request.project_id {
2393 let project_id = ProjectId::from_proto(project_id);
2394 database
2395 .project_connection_ids(project_id, session.connection_id, true)
2396 .await?
2397 } else {
2398 database
2399 .room_connection_ids(room_id, session.connection_id)
2400 .await?
2401 };
2402
2403 // For now, don't send view update messages back to that view's current leader.
2404 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2405 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2406 _ => None,
2407 });
2408
2409 for connection_id in connection_ids.iter().cloned() {
2410 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2411 session
2412 .peer
2413 .forward_send(session.connection_id, connection_id, request.clone())?;
2414 }
2415 }
2416 Ok(())
2417}
2418
2419/// Get public data about users.
2420async fn get_users(
2421 request: proto::GetUsers,
2422 response: Response<proto::GetUsers>,
2423 session: Session,
2424) -> Result<()> {
2425 let user_ids = request
2426 .user_ids
2427 .into_iter()
2428 .map(UserId::from_proto)
2429 .collect();
2430 let users = session
2431 .db()
2432 .await
2433 .get_users_by_ids(user_ids)
2434 .await?
2435 .into_iter()
2436 .map(|user| proto::User {
2437 id: user.id.to_proto(),
2438 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2439 github_login: user.github_login,
2440 email: user.email_address,
2441 name: user.name,
2442 })
2443 .collect();
2444 response.send(proto::UsersResponse { users })?;
2445 Ok(())
2446}
2447
2448/// Search for users (to invite) buy Github login
2449async fn fuzzy_search_users(
2450 request: proto::FuzzySearchUsers,
2451 response: Response<proto::FuzzySearchUsers>,
2452 session: Session,
2453) -> Result<()> {
2454 let query = request.query;
2455 let users = match query.len() {
2456 0 => vec![],
2457 1 | 2 => session
2458 .db()
2459 .await
2460 .get_user_by_github_login(&query)
2461 .await?
2462 .into_iter()
2463 .collect(),
2464 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2465 };
2466 let users = users
2467 .into_iter()
2468 .filter(|user| user.id != session.user_id())
2469 .map(|user| proto::User {
2470 id: user.id.to_proto(),
2471 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2472 github_login: user.github_login,
2473 name: user.name,
2474 email: user.email_address,
2475 })
2476 .collect();
2477 response.send(proto::UsersResponse { users })?;
2478 Ok(())
2479}
2480
2481/// Send a contact request to another user.
2482async fn request_contact(
2483 request: proto::RequestContact,
2484 response: Response<proto::RequestContact>,
2485 session: Session,
2486) -> Result<()> {
2487 let requester_id = session.user_id();
2488 let responder_id = UserId::from_proto(request.responder_id);
2489 if requester_id == responder_id {
2490 return Err(anyhow!("cannot add yourself as a contact"))?;
2491 }
2492
2493 let notifications = session
2494 .db()
2495 .await
2496 .send_contact_request(requester_id, responder_id)
2497 .await?;
2498
2499 // Update outgoing contact requests of requester
2500 let mut update = proto::UpdateContacts::default();
2501 update.outgoing_requests.push(responder_id.to_proto());
2502 for connection_id in session
2503 .connection_pool()
2504 .await
2505 .user_connection_ids(requester_id)
2506 {
2507 session.peer.send(connection_id, update.clone())?;
2508 }
2509
2510 // Update incoming contact requests of responder
2511 let mut update = proto::UpdateContacts::default();
2512 update
2513 .incoming_requests
2514 .push(proto::IncomingContactRequest {
2515 requester_id: requester_id.to_proto(),
2516 });
2517 let connection_pool = session.connection_pool().await;
2518 for connection_id in connection_pool.user_connection_ids(responder_id) {
2519 session.peer.send(connection_id, update.clone())?;
2520 }
2521
2522 send_notifications(&connection_pool, &session.peer, notifications);
2523
2524 response.send(proto::Ack {})?;
2525 Ok(())
2526}
2527
2528/// Accept or decline a contact request
2529async fn respond_to_contact_request(
2530 request: proto::RespondToContactRequest,
2531 response: Response<proto::RespondToContactRequest>,
2532 session: Session,
2533) -> Result<()> {
2534 let responder_id = session.user_id();
2535 let requester_id = UserId::from_proto(request.requester_id);
2536 let db = session.db().await;
2537 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2538 db.dismiss_contact_notification(responder_id, requester_id)
2539 .await?;
2540 } else {
2541 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2542
2543 let notifications = db
2544 .respond_to_contact_request(responder_id, requester_id, accept)
2545 .await?;
2546 let requester_busy = db.is_user_busy(requester_id).await?;
2547 let responder_busy = db.is_user_busy(responder_id).await?;
2548
2549 let pool = session.connection_pool().await;
2550 // Update responder with new contact
2551 let mut update = proto::UpdateContacts::default();
2552 if accept {
2553 update
2554 .contacts
2555 .push(contact_for_user(requester_id, requester_busy, &pool));
2556 }
2557 update
2558 .remove_incoming_requests
2559 .push(requester_id.to_proto());
2560 for connection_id in pool.user_connection_ids(responder_id) {
2561 session.peer.send(connection_id, update.clone())?;
2562 }
2563
2564 // Update requester with new contact
2565 let mut update = proto::UpdateContacts::default();
2566 if accept {
2567 update
2568 .contacts
2569 .push(contact_for_user(responder_id, responder_busy, &pool));
2570 }
2571 update
2572 .remove_outgoing_requests
2573 .push(responder_id.to_proto());
2574
2575 for connection_id in pool.user_connection_ids(requester_id) {
2576 session.peer.send(connection_id, update.clone())?;
2577 }
2578
2579 send_notifications(&pool, &session.peer, notifications);
2580 }
2581
2582 response.send(proto::Ack {})?;
2583 Ok(())
2584}
2585
2586/// Remove a contact.
2587async fn remove_contact(
2588 request: proto::RemoveContact,
2589 response: Response<proto::RemoveContact>,
2590 session: Session,
2591) -> Result<()> {
2592 let requester_id = session.user_id();
2593 let responder_id = UserId::from_proto(request.user_id);
2594 let db = session.db().await;
2595 let (contact_accepted, deleted_notification_id) =
2596 db.remove_contact(requester_id, responder_id).await?;
2597
2598 let pool = session.connection_pool().await;
2599 // Update outgoing contact requests of requester
2600 let mut update = proto::UpdateContacts::default();
2601 if contact_accepted {
2602 update.remove_contacts.push(responder_id.to_proto());
2603 } else {
2604 update
2605 .remove_outgoing_requests
2606 .push(responder_id.to_proto());
2607 }
2608 for connection_id in pool.user_connection_ids(requester_id) {
2609 session.peer.send(connection_id, update.clone())?;
2610 }
2611
2612 // Update incoming contact requests of responder
2613 let mut update = proto::UpdateContacts::default();
2614 if contact_accepted {
2615 update.remove_contacts.push(requester_id.to_proto());
2616 } else {
2617 update
2618 .remove_incoming_requests
2619 .push(requester_id.to_proto());
2620 }
2621 for connection_id in pool.user_connection_ids(responder_id) {
2622 session.peer.send(connection_id, update.clone())?;
2623 if let Some(notification_id) = deleted_notification_id {
2624 session.peer.send(
2625 connection_id,
2626 proto::DeleteNotification {
2627 notification_id: notification_id.to_proto(),
2628 },
2629 )?;
2630 }
2631 }
2632
2633 response.send(proto::Ack {})?;
2634 Ok(())
2635}
2636
2637fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2638 version.0.minor() < 139
2639}
2640
2641async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2642 let plan = session.current_plan(&session.db().await).await?;
2643
2644 session
2645 .peer
2646 .send(
2647 session.connection_id,
2648 proto::UpdateUserPlan { plan: plan.into() },
2649 )
2650 .trace_err();
2651
2652 Ok(())
2653}
2654
2655async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2656 subscribe_user_to_channels(session.user_id(), &session).await?;
2657 Ok(())
2658}
2659
2660async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2661 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2662 let mut pool = session.connection_pool().await;
2663 for membership in &channels_for_user.channel_memberships {
2664 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2665 }
2666 session.peer.send(
2667 session.connection_id,
2668 build_update_user_channels(&channels_for_user),
2669 )?;
2670 session.peer.send(
2671 session.connection_id,
2672 build_channels_update(channels_for_user),
2673 )?;
2674 Ok(())
2675}
2676
2677/// Creates a new channel.
2678async fn create_channel(
2679 request: proto::CreateChannel,
2680 response: Response<proto::CreateChannel>,
2681 session: Session,
2682) -> Result<()> {
2683 let db = session.db().await;
2684
2685 let parent_id = request.parent_id.map(ChannelId::from_proto);
2686 let (channel, membership) = db
2687 .create_channel(&request.name, parent_id, session.user_id())
2688 .await?;
2689
2690 let root_id = channel.root_id();
2691 let channel = Channel::from_model(channel);
2692
2693 response.send(proto::CreateChannelResponse {
2694 channel: Some(channel.to_proto()),
2695 parent_id: request.parent_id,
2696 })?;
2697
2698 let mut connection_pool = session.connection_pool().await;
2699 if let Some(membership) = membership {
2700 connection_pool.subscribe_to_channel(
2701 membership.user_id,
2702 membership.channel_id,
2703 membership.role,
2704 );
2705 let update = proto::UpdateUserChannels {
2706 channel_memberships: vec![proto::ChannelMembership {
2707 channel_id: membership.channel_id.to_proto(),
2708 role: membership.role.into(),
2709 }],
2710 ..Default::default()
2711 };
2712 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2713 session.peer.send(connection_id, update.clone())?;
2714 }
2715 }
2716
2717 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2718 if !role.can_see_channel(channel.visibility) {
2719 continue;
2720 }
2721
2722 let update = proto::UpdateChannels {
2723 channels: vec![channel.to_proto()],
2724 ..Default::default()
2725 };
2726 session.peer.send(connection_id, update.clone())?;
2727 }
2728
2729 Ok(())
2730}
2731
2732/// Delete a channel
2733async fn delete_channel(
2734 request: proto::DeleteChannel,
2735 response: Response<proto::DeleteChannel>,
2736 session: Session,
2737) -> Result<()> {
2738 let db = session.db().await;
2739
2740 let channel_id = request.channel_id;
2741 let (root_channel, removed_channels) = db
2742 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2743 .await?;
2744 response.send(proto::Ack {})?;
2745
2746 // Notify members of removed channels
2747 let mut update = proto::UpdateChannels::default();
2748 update
2749 .delete_channels
2750 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2751
2752 let connection_pool = session.connection_pool().await;
2753 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2754 session.peer.send(connection_id, update.clone())?;
2755 }
2756
2757 Ok(())
2758}
2759
2760/// Invite someone to join a channel.
2761async fn invite_channel_member(
2762 request: proto::InviteChannelMember,
2763 response: Response<proto::InviteChannelMember>,
2764 session: Session,
2765) -> Result<()> {
2766 let db = session.db().await;
2767 let channel_id = ChannelId::from_proto(request.channel_id);
2768 let invitee_id = UserId::from_proto(request.user_id);
2769 let InviteMemberResult {
2770 channel,
2771 notifications,
2772 } = db
2773 .invite_channel_member(
2774 channel_id,
2775 invitee_id,
2776 session.user_id(),
2777 request.role().into(),
2778 )
2779 .await?;
2780
2781 let update = proto::UpdateChannels {
2782 channel_invitations: vec![channel.to_proto()],
2783 ..Default::default()
2784 };
2785
2786 let connection_pool = session.connection_pool().await;
2787 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2788 session.peer.send(connection_id, update.clone())?;
2789 }
2790
2791 send_notifications(&connection_pool, &session.peer, notifications);
2792
2793 response.send(proto::Ack {})?;
2794 Ok(())
2795}
2796
2797/// remove someone from a channel
2798async fn remove_channel_member(
2799 request: proto::RemoveChannelMember,
2800 response: Response<proto::RemoveChannelMember>,
2801 session: Session,
2802) -> Result<()> {
2803 let db = session.db().await;
2804 let channel_id = ChannelId::from_proto(request.channel_id);
2805 let member_id = UserId::from_proto(request.user_id);
2806
2807 let RemoveChannelMemberResult {
2808 membership_update,
2809 notification_id,
2810 } = db
2811 .remove_channel_member(channel_id, member_id, session.user_id())
2812 .await?;
2813
2814 let mut connection_pool = session.connection_pool().await;
2815 notify_membership_updated(
2816 &mut connection_pool,
2817 membership_update,
2818 member_id,
2819 &session.peer,
2820 );
2821 for connection_id in connection_pool.user_connection_ids(member_id) {
2822 if let Some(notification_id) = notification_id {
2823 session
2824 .peer
2825 .send(
2826 connection_id,
2827 proto::DeleteNotification {
2828 notification_id: notification_id.to_proto(),
2829 },
2830 )
2831 .trace_err();
2832 }
2833 }
2834
2835 response.send(proto::Ack {})?;
2836 Ok(())
2837}
2838
2839/// Toggle the channel between public and private.
2840/// Care is taken to maintain the invariant that public channels only descend from public channels,
2841/// (though members-only channels can appear at any point in the hierarchy).
2842async fn set_channel_visibility(
2843 request: proto::SetChannelVisibility,
2844 response: Response<proto::SetChannelVisibility>,
2845 session: Session,
2846) -> Result<()> {
2847 let db = session.db().await;
2848 let channel_id = ChannelId::from_proto(request.channel_id);
2849 let visibility = request.visibility().into();
2850
2851 let channel_model = db
2852 .set_channel_visibility(channel_id, visibility, session.user_id())
2853 .await?;
2854 let root_id = channel_model.root_id();
2855 let channel = Channel::from_model(channel_model);
2856
2857 let mut connection_pool = session.connection_pool().await;
2858 for (user_id, role) in connection_pool
2859 .channel_user_ids(root_id)
2860 .collect::<Vec<_>>()
2861 .into_iter()
2862 {
2863 let update = if role.can_see_channel(channel.visibility) {
2864 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2865 proto::UpdateChannels {
2866 channels: vec![channel.to_proto()],
2867 ..Default::default()
2868 }
2869 } else {
2870 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2871 proto::UpdateChannels {
2872 delete_channels: vec![channel.id.to_proto()],
2873 ..Default::default()
2874 }
2875 };
2876
2877 for connection_id in connection_pool.user_connection_ids(user_id) {
2878 session.peer.send(connection_id, update.clone())?;
2879 }
2880 }
2881
2882 response.send(proto::Ack {})?;
2883 Ok(())
2884}
2885
2886/// Alter the role for a user in the channel.
2887async fn set_channel_member_role(
2888 request: proto::SetChannelMemberRole,
2889 response: Response<proto::SetChannelMemberRole>,
2890 session: Session,
2891) -> Result<()> {
2892 let db = session.db().await;
2893 let channel_id = ChannelId::from_proto(request.channel_id);
2894 let member_id = UserId::from_proto(request.user_id);
2895 let result = db
2896 .set_channel_member_role(
2897 channel_id,
2898 session.user_id(),
2899 member_id,
2900 request.role().into(),
2901 )
2902 .await?;
2903
2904 match result {
2905 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2906 let mut connection_pool = session.connection_pool().await;
2907 notify_membership_updated(
2908 &mut connection_pool,
2909 membership_update,
2910 member_id,
2911 &session.peer,
2912 )
2913 }
2914 db::SetMemberRoleResult::InviteUpdated(channel) => {
2915 let update = proto::UpdateChannels {
2916 channel_invitations: vec![channel.to_proto()],
2917 ..Default::default()
2918 };
2919
2920 for connection_id in session
2921 .connection_pool()
2922 .await
2923 .user_connection_ids(member_id)
2924 {
2925 session.peer.send(connection_id, update.clone())?;
2926 }
2927 }
2928 }
2929
2930 response.send(proto::Ack {})?;
2931 Ok(())
2932}
2933
2934/// Change the name of a channel
2935async fn rename_channel(
2936 request: proto::RenameChannel,
2937 response: Response<proto::RenameChannel>,
2938 session: Session,
2939) -> Result<()> {
2940 let db = session.db().await;
2941 let channel_id = ChannelId::from_proto(request.channel_id);
2942 let channel_model = db
2943 .rename_channel(channel_id, session.user_id(), &request.name)
2944 .await?;
2945 let root_id = channel_model.root_id();
2946 let channel = Channel::from_model(channel_model);
2947
2948 response.send(proto::RenameChannelResponse {
2949 channel: Some(channel.to_proto()),
2950 })?;
2951
2952 let connection_pool = session.connection_pool().await;
2953 let update = proto::UpdateChannels {
2954 channels: vec![channel.to_proto()],
2955 ..Default::default()
2956 };
2957 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2958 if role.can_see_channel(channel.visibility) {
2959 session.peer.send(connection_id, update.clone())?;
2960 }
2961 }
2962
2963 Ok(())
2964}
2965
2966/// Move a channel to a new parent.
2967async fn move_channel(
2968 request: proto::MoveChannel,
2969 response: Response<proto::MoveChannel>,
2970 session: Session,
2971) -> Result<()> {
2972 let channel_id = ChannelId::from_proto(request.channel_id);
2973 let to = ChannelId::from_proto(request.to);
2974
2975 let (root_id, channels) = session
2976 .db()
2977 .await
2978 .move_channel(channel_id, to, session.user_id())
2979 .await?;
2980
2981 let connection_pool = session.connection_pool().await;
2982 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2983 let channels = channels
2984 .iter()
2985 .filter_map(|channel| {
2986 if role.can_see_channel(channel.visibility) {
2987 Some(channel.to_proto())
2988 } else {
2989 None
2990 }
2991 })
2992 .collect::<Vec<_>>();
2993 if channels.is_empty() {
2994 continue;
2995 }
2996
2997 let update = proto::UpdateChannels {
2998 channels,
2999 ..Default::default()
3000 };
3001
3002 session.peer.send(connection_id, update.clone())?;
3003 }
3004
3005 response.send(Ack {})?;
3006 Ok(())
3007}
3008
3009/// Get the list of channel members
3010async fn get_channel_members(
3011 request: proto::GetChannelMembers,
3012 response: Response<proto::GetChannelMembers>,
3013 session: Session,
3014) -> Result<()> {
3015 let db = session.db().await;
3016 let channel_id = ChannelId::from_proto(request.channel_id);
3017 let limit = if request.limit == 0 {
3018 u16::MAX as u64
3019 } else {
3020 request.limit
3021 };
3022 let (members, users) = db
3023 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3024 .await?;
3025 response.send(proto::GetChannelMembersResponse { members, users })?;
3026 Ok(())
3027}
3028
3029/// Accept or decline a channel invitation.
3030async fn respond_to_channel_invite(
3031 request: proto::RespondToChannelInvite,
3032 response: Response<proto::RespondToChannelInvite>,
3033 session: Session,
3034) -> Result<()> {
3035 let db = session.db().await;
3036 let channel_id = ChannelId::from_proto(request.channel_id);
3037 let RespondToChannelInvite {
3038 membership_update,
3039 notifications,
3040 } = db
3041 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3042 .await?;
3043
3044 let mut connection_pool = session.connection_pool().await;
3045 if let Some(membership_update) = membership_update {
3046 notify_membership_updated(
3047 &mut connection_pool,
3048 membership_update,
3049 session.user_id(),
3050 &session.peer,
3051 );
3052 } else {
3053 let update = proto::UpdateChannels {
3054 remove_channel_invitations: vec![channel_id.to_proto()],
3055 ..Default::default()
3056 };
3057
3058 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3059 session.peer.send(connection_id, update.clone())?;
3060 }
3061 };
3062
3063 send_notifications(&connection_pool, &session.peer, notifications);
3064
3065 response.send(proto::Ack {})?;
3066
3067 Ok(())
3068}
3069
3070/// Join the channels' room
3071async fn join_channel(
3072 request: proto::JoinChannel,
3073 response: Response<proto::JoinChannel>,
3074 session: Session,
3075) -> Result<()> {
3076 let channel_id = ChannelId::from_proto(request.channel_id);
3077 join_channel_internal(channel_id, Box::new(response), session).await
3078}
3079
/// Abstraction over response handles that can be answered with a
/// `proto::JoinRoomResponse`, which lets `join_channel_internal` accept either
/// of the handle types implemented below.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call resolves to `Response`'s own `send`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call resolves to `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3093
/// Joins the session's user to the room of `channel_id` and answers through
/// `response` with a `proto::JoinRoomResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for the user, that
///    connection is made to leave its room first.
/// 2. The user joins the channel in the database.
/// 3. If a LiveKit client is configured, a token is created: a guest token
///    (`can_publish: false`) for `ChannelRole::Guest`, otherwise a room token
///    (`can_publish: true`). If token creation fails, no connection info is
///    sent.
/// 4. The response is sent, then the membership update (if any) is passed to
///    `notify_membership_updated` and the room to `room_updated`.
/// 5. `channel_updated` and `update_user_contacts` are called.
///
/// Fails with "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The database handle is released while the stale connection leaves
            // its room and re-acquired afterwards (presumably because
            // `leave_room_for_session` acquires it itself — confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when creating the
        // token fails (the error is passed to `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // The caller is answered before the other participants are notified.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards taken inside the block above are dropped here; the connection
    // pool is acquired again just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3189
3190/// Start editing the channel notes
3191async fn join_channel_buffer(
3192 request: proto::JoinChannelBuffer,
3193 response: Response<proto::JoinChannelBuffer>,
3194 session: Session,
3195) -> Result<()> {
3196 let db = session.db().await;
3197 let channel_id = ChannelId::from_proto(request.channel_id);
3198
3199 let open_response = db
3200 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3201 .await?;
3202
3203 let collaborators = open_response.collaborators.clone();
3204 response.send(open_response)?;
3205
3206 let update = UpdateChannelBufferCollaborators {
3207 channel_id: channel_id.to_proto(),
3208 collaborators: collaborators.clone(),
3209 };
3210 channel_buffer_updated(
3211 session.connection_id,
3212 collaborators
3213 .iter()
3214 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3215 &update,
3216 &session.peer,
3217 );
3218
3219 Ok(())
3220}
3221
3222/// Edit the channel notes
3223async fn update_channel_buffer(
3224 request: proto::UpdateChannelBuffer,
3225 session: Session,
3226) -> Result<()> {
3227 let db = session.db().await;
3228 let channel_id = ChannelId::from_proto(request.channel_id);
3229
3230 let (collaborators, epoch, version) = db
3231 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3232 .await?;
3233
3234 channel_buffer_updated(
3235 session.connection_id,
3236 collaborators.clone(),
3237 &proto::UpdateChannelBuffer {
3238 channel_id: channel_id.to_proto(),
3239 operations: request.operations,
3240 },
3241 &session.peer,
3242 );
3243
3244 let pool = &*session.connection_pool().await;
3245
3246 let non_collaborators =
3247 pool.channel_connection_ids(channel_id)
3248 .filter_map(|(connection_id, _)| {
3249 if collaborators.contains(&connection_id) {
3250 None
3251 } else {
3252 Some(connection_id)
3253 }
3254 });
3255
3256 broadcast(None, non_collaborators, |peer_id| {
3257 session.peer.send(
3258 peer_id,
3259 proto::UpdateChannels {
3260 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3261 channel_id: channel_id.to_proto(),
3262 epoch: epoch as u64,
3263 version: version.clone(),
3264 }],
3265 ..Default::default()
3266 },
3267 )
3268 });
3269
3270 Ok(())
3271}
3272
3273/// Rejoin the channel notes after a connection blip
3274async fn rejoin_channel_buffers(
3275 request: proto::RejoinChannelBuffers,
3276 response: Response<proto::RejoinChannelBuffers>,
3277 session: Session,
3278) -> Result<()> {
3279 let db = session.db().await;
3280 let buffers = db
3281 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3282 .await?;
3283
3284 for rejoined_buffer in &buffers {
3285 let collaborators_to_notify = rejoined_buffer
3286 .buffer
3287 .collaborators
3288 .iter()
3289 .filter_map(|c| Some(c.peer_id?.into()));
3290 channel_buffer_updated(
3291 session.connection_id,
3292 collaborators_to_notify,
3293 &proto::UpdateChannelBufferCollaborators {
3294 channel_id: rejoined_buffer.buffer.channel_id,
3295 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3296 },
3297 &session.peer,
3298 );
3299 }
3300
3301 response.send(proto::RejoinChannelBuffersResponse {
3302 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3303 })?;
3304
3305 Ok(())
3306}
3307
3308/// Stop editing the channel notes
3309async fn leave_channel_buffer(
3310 request: proto::LeaveChannelBuffer,
3311 response: Response<proto::LeaveChannelBuffer>,
3312 session: Session,
3313) -> Result<()> {
3314 let db = session.db().await;
3315 let channel_id = ChannelId::from_proto(request.channel_id);
3316
3317 let left_buffer = db
3318 .leave_channel_buffer(channel_id, session.connection_id)
3319 .await?;
3320
3321 response.send(Ack {})?;
3322
3323 channel_buffer_updated(
3324 session.connection_id,
3325 left_buffer.connections,
3326 &proto::UpdateChannelBufferCollaborators {
3327 channel_id: channel_id.to_proto(),
3328 collaborators: left_buffer.collaborators,
3329 },
3330 &session.peer,
3331 );
3332
3333 Ok(())
3334}
3335
3336fn channel_buffer_updated<T: EnvelopedMessage>(
3337 sender_id: ConnectionId,
3338 collaborators: impl IntoIterator<Item = ConnectionId>,
3339 message: &T,
3340 peer: &Peer,
3341) {
3342 broadcast(Some(sender_id), collaborators, |peer_id| {
3343 peer.send(peer_id, message.clone())
3344 });
3345}
3346
3347fn send_notifications(
3348 connection_pool: &ConnectionPool,
3349 peer: &Peer,
3350 notifications: db::NotificationBatch,
3351) {
3352 for (user_id, notification) in notifications {
3353 for connection_id in connection_pool.user_connection_ids(user_id) {
3354 if let Err(error) = peer.send(
3355 connection_id,
3356 proto::AddNotification {
3357 notification: Some(notification.clone()),
3358 },
3359 ) {
3360 tracing::error!(
3361 "failed to send notification to {:?} {}",
3362 connection_id,
3363 error
3364 );
3365 }
3366 }
3367 }
3368}
3369
3370/// Send a message to the channel
3371async fn send_channel_message(
3372 request: proto::SendChannelMessage,
3373 response: Response<proto::SendChannelMessage>,
3374 session: Session,
3375) -> Result<()> {
3376 // Validate the message body.
3377 let body = request.body.trim().to_string();
3378 if body.len() > MAX_MESSAGE_LEN {
3379 return Err(anyhow!("message is too long"))?;
3380 }
3381 if body.is_empty() {
3382 return Err(anyhow!("message can't be blank"))?;
3383 }
3384
3385 // TODO: adjust mentions if body is trimmed
3386
3387 let timestamp = OffsetDateTime::now_utc();
3388 let nonce = request
3389 .nonce
3390 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3391
3392 let channel_id = ChannelId::from_proto(request.channel_id);
3393 let CreatedChannelMessage {
3394 message_id,
3395 participant_connection_ids,
3396 notifications,
3397 } = session
3398 .db()
3399 .await
3400 .create_channel_message(
3401 channel_id,
3402 session.user_id(),
3403 &body,
3404 &request.mentions,
3405 timestamp,
3406 nonce.clone().into(),
3407 request.reply_to_message_id.map(MessageId::from_proto),
3408 )
3409 .await?;
3410
3411 let message = proto::ChannelMessage {
3412 sender_id: session.user_id().to_proto(),
3413 id: message_id.to_proto(),
3414 body,
3415 mentions: request.mentions,
3416 timestamp: timestamp.unix_timestamp() as u64,
3417 nonce: Some(nonce),
3418 reply_to_message_id: request.reply_to_message_id,
3419 edited_at: None,
3420 };
3421 broadcast(
3422 Some(session.connection_id),
3423 participant_connection_ids.clone(),
3424 |connection| {
3425 session.peer.send(
3426 connection,
3427 proto::ChannelMessageSent {
3428 channel_id: channel_id.to_proto(),
3429 message: Some(message.clone()),
3430 },
3431 )
3432 },
3433 );
3434 response.send(proto::SendChannelMessageResponse {
3435 message: Some(message),
3436 })?;
3437
3438 let pool = &*session.connection_pool().await;
3439 let non_participants =
3440 pool.channel_connection_ids(channel_id)
3441 .filter_map(|(connection_id, _)| {
3442 if participant_connection_ids.contains(&connection_id) {
3443 None
3444 } else {
3445 Some(connection_id)
3446 }
3447 });
3448 broadcast(None, non_participants, |peer_id| {
3449 session.peer.send(
3450 peer_id,
3451 proto::UpdateChannels {
3452 latest_channel_message_ids: vec![proto::ChannelMessageId {
3453 channel_id: channel_id.to_proto(),
3454 message_id: message_id.to_proto(),
3455 }],
3456 ..Default::default()
3457 },
3458 )
3459 });
3460 send_notifications(pool, &session.peer, notifications);
3461
3462 Ok(())
3463}
3464
3465/// Delete a channel message
3466async fn remove_channel_message(
3467 request: proto::RemoveChannelMessage,
3468 response: Response<proto::RemoveChannelMessage>,
3469 session: Session,
3470) -> Result<()> {
3471 let channel_id = ChannelId::from_proto(request.channel_id);
3472 let message_id = MessageId::from_proto(request.message_id);
3473 let (connection_ids, existing_notification_ids) = session
3474 .db()
3475 .await
3476 .remove_channel_message(channel_id, message_id, session.user_id())
3477 .await?;
3478
3479 broadcast(
3480 Some(session.connection_id),
3481 connection_ids,
3482 move |connection| {
3483 session.peer.send(connection, request.clone())?;
3484
3485 for notification_id in &existing_notification_ids {
3486 session.peer.send(
3487 connection,
3488 proto::DeleteNotification {
3489 notification_id: (*notification_id).to_proto(),
3490 },
3491 )?;
3492 }
3493
3494 Ok(())
3495 },
3496 );
3497 response.send(proto::Ack {})?;
3498 Ok(())
3499}
3500
3501async fn update_channel_message(
3502 request: proto::UpdateChannelMessage,
3503 response: Response<proto::UpdateChannelMessage>,
3504 session: Session,
3505) -> Result<()> {
3506 let channel_id = ChannelId::from_proto(request.channel_id);
3507 let message_id = MessageId::from_proto(request.message_id);
3508 let updated_at = OffsetDateTime::now_utc();
3509 let UpdatedChannelMessage {
3510 message_id,
3511 participant_connection_ids,
3512 notifications,
3513 reply_to_message_id,
3514 timestamp,
3515 deleted_mention_notification_ids,
3516 updated_mention_notifications,
3517 } = session
3518 .db()
3519 .await
3520 .update_channel_message(
3521 channel_id,
3522 message_id,
3523 session.user_id(),
3524 request.body.as_str(),
3525 &request.mentions,
3526 updated_at,
3527 )
3528 .await?;
3529
3530 let nonce = request
3531 .nonce
3532 .clone()
3533 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3534
3535 let message = proto::ChannelMessage {
3536 sender_id: session.user_id().to_proto(),
3537 id: message_id.to_proto(),
3538 body: request.body.clone(),
3539 mentions: request.mentions.clone(),
3540 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3541 nonce: Some(nonce),
3542 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3543 edited_at: Some(updated_at.unix_timestamp() as u64),
3544 };
3545
3546 response.send(proto::Ack {})?;
3547
3548 let pool = &*session.connection_pool().await;
3549 broadcast(
3550 Some(session.connection_id),
3551 participant_connection_ids,
3552 |connection| {
3553 session.peer.send(
3554 connection,
3555 proto::ChannelMessageUpdate {
3556 channel_id: channel_id.to_proto(),
3557 message: Some(message.clone()),
3558 },
3559 )?;
3560
3561 for notification_id in &deleted_mention_notification_ids {
3562 session.peer.send(
3563 connection,
3564 proto::DeleteNotification {
3565 notification_id: (*notification_id).to_proto(),
3566 },
3567 )?;
3568 }
3569
3570 for notification in &updated_mention_notifications {
3571 session.peer.send(
3572 connection,
3573 proto::UpdateNotification {
3574 notification: Some(notification.clone()),
3575 },
3576 )?;
3577 }
3578
3579 Ok(())
3580 },
3581 );
3582
3583 send_notifications(pool, &session.peer, notifications);
3584
3585 Ok(())
3586}
3587
3588/// Mark a channel message as read
3589async fn acknowledge_channel_message(
3590 request: proto::AckChannelMessage,
3591 session: Session,
3592) -> Result<()> {
3593 let channel_id = ChannelId::from_proto(request.channel_id);
3594 let message_id = MessageId::from_proto(request.message_id);
3595 let notifications = session
3596 .db()
3597 .await
3598 .observe_channel_message(channel_id, session.user_id(), message_id)
3599 .await?;
3600 send_notifications(
3601 &*session.connection_pool().await,
3602 &session.peer,
3603 notifications,
3604 );
3605 Ok(())
3606}
3607
3608/// Mark a buffer version as synced
3609async fn acknowledge_buffer_version(
3610 request: proto::AckBufferOperation,
3611 session: Session,
3612) -> Result<()> {
3613 let buffer_id = BufferId::from_proto(request.buffer_id);
3614 session
3615 .db()
3616 .await
3617 .observe_buffer_version(
3618 buffer_id,
3619 session.user_id(),
3620 request.epoch as i32,
3621 &request.version,
3622 )
3623 .await?;
3624 Ok(())
3625}
3626
3627async fn count_language_model_tokens(
3628 request: proto::CountLanguageModelTokens,
3629 response: Response<proto::CountLanguageModelTokens>,
3630 session: Session,
3631 config: &Config,
3632) -> Result<()> {
3633 authorize_access_to_legacy_llm_endpoints(&session).await?;
3634
3635 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3636 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3637 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3638 };
3639
3640 session
3641 .app_state
3642 .rate_limiter
3643 .check(&*rate_limit, session.user_id())
3644 .await?;
3645
3646 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3647 Some(proto::LanguageModelProvider::Google) => {
3648 let api_key = config
3649 .google_ai_api_key
3650 .as_ref()
3651 .context("no Google AI API key configured on the server")?;
3652 google_ai::count_tokens(
3653 session.http_client.as_ref(),
3654 google_ai::API_URL,
3655 api_key,
3656 serde_json::from_str(&request.request)?,
3657 )
3658 .await?
3659 }
3660 _ => return Err(anyhow!("unsupported provider"))?,
3661 };
3662
3663 response.send(proto::CountLanguageModelTokensResponse {
3664 token_count: result.total_tokens as u32,
3665 })?;
3666
3667 Ok(())
3668}
3669
3670struct ZedProCountLanguageModelTokensRateLimit;
3671
3672impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3673 fn capacity(&self) -> usize {
3674 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3675 .ok()
3676 .and_then(|v| v.parse().ok())
3677 .unwrap_or(600) // Picked arbitrarily
3678 }
3679
3680 fn refill_duration(&self) -> chrono::Duration {
3681 chrono::Duration::hours(1)
3682 }
3683
3684 fn db_name(&self) -> &'static str {
3685 "zed-pro:count-language-model-tokens"
3686 }
3687}
3688
3689struct FreeCountLanguageModelTokensRateLimit;
3690
3691impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3692 fn capacity(&self) -> usize {
3693 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3694 .ok()
3695 .and_then(|v| v.parse().ok())
3696 .unwrap_or(600 / 10) // Picked arbitrarily
3697 }
3698
3699 fn refill_duration(&self) -> chrono::Duration {
3700 chrono::Duration::hours(1)
3701 }
3702
3703 fn db_name(&self) -> &'static str {
3704 "free:count-language-model-tokens"
3705 }
3706}
3707
3708struct ZedProComputeEmbeddingsRateLimit;
3709
3710impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3711 fn capacity(&self) -> usize {
3712 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3713 .ok()
3714 .and_then(|v| v.parse().ok())
3715 .unwrap_or(5000) // Picked arbitrarily
3716 }
3717
3718 fn refill_duration(&self) -> chrono::Duration {
3719 chrono::Duration::hours(1)
3720 }
3721
3722 fn db_name(&self) -> &'static str {
3723 "zed-pro:compute-embeddings"
3724 }
3725}
3726
3727struct FreeComputeEmbeddingsRateLimit;
3728
3729impl RateLimit for FreeComputeEmbeddingsRateLimit {
3730 fn capacity(&self) -> usize {
3731 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3732 .ok()
3733 .and_then(|v| v.parse().ok())
3734 .unwrap_or(5000 / 10) // Picked arbitrarily
3735 }
3736
3737 fn refill_duration(&self) -> chrono::Duration {
3738 chrono::Duration::hours(1)
3739 }
3740
3741 fn db_name(&self) -> &'static str {
3742 "free:compute-embeddings"
3743 }
3744}
3745
3746async fn compute_embeddings(
3747 request: proto::ComputeEmbeddings,
3748 response: Response<proto::ComputeEmbeddings>,
3749 session: Session,
3750 api_key: Option<Arc<str>>,
3751) -> Result<()> {
3752 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3753 authorize_access_to_legacy_llm_endpoints(&session).await?;
3754
3755 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3756 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3757 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3758 };
3759
3760 session
3761 .app_state
3762 .rate_limiter
3763 .check(&*rate_limit, session.user_id())
3764 .await?;
3765
3766 let embeddings = match request.model.as_str() {
3767 "openai/text-embedding-3-small" => {
3768 open_ai::embed(
3769 session.http_client.as_ref(),
3770 OPEN_AI_API_URL,
3771 &api_key,
3772 OpenAiEmbeddingModel::TextEmbedding3Small,
3773 request.texts.iter().map(|text| text.as_str()),
3774 )
3775 .await?
3776 }
3777 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3778 };
3779
3780 let embeddings = request
3781 .texts
3782 .iter()
3783 .map(|text| {
3784 let mut hasher = sha2::Sha256::new();
3785 hasher.update(text.as_bytes());
3786 let result = hasher.finalize();
3787 result.to_vec()
3788 })
3789 .zip(
3790 embeddings
3791 .data
3792 .into_iter()
3793 .map(|embedding| embedding.embedding),
3794 )
3795 .collect::<HashMap<_, _>>();
3796
3797 let db = session.db().await;
3798 db.save_embeddings(&request.model, &embeddings)
3799 .await
3800 .context("failed to save embeddings")
3801 .trace_err();
3802
3803 response.send(proto::ComputeEmbeddingsResponse {
3804 embeddings: embeddings
3805 .into_iter()
3806 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3807 .collect(),
3808 })?;
3809 Ok(())
3810}
3811
3812async fn get_cached_embeddings(
3813 request: proto::GetCachedEmbeddings,
3814 response: Response<proto::GetCachedEmbeddings>,
3815 session: Session,
3816) -> Result<()> {
3817 authorize_access_to_legacy_llm_endpoints(&session).await?;
3818
3819 let db = session.db().await;
3820 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3821
3822 response.send(proto::GetCachedEmbeddingsResponse {
3823 embeddings: embeddings
3824 .into_iter()
3825 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3826 .collect(),
3827 })?;
3828 Ok(())
3829}
3830
3831/// This is leftover from before the LLM service.
3832///
3833/// The endpoints protected by this check will be moved there eventually.
3834async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3835 if session.is_staff() {
3836 Ok(())
3837 } else {
3838 Err(anyhow!("permission denied"))?
3839 }
3840}
3841
3842/// Get a Supermaven API key for the user
3843async fn get_supermaven_api_key(
3844 _request: proto::GetSupermavenApiKey,
3845 response: Response<proto::GetSupermavenApiKey>,
3846 session: Session,
3847) -> Result<()> {
3848 let user_id: String = session.user_id().to_string();
3849 if !session.is_staff() {
3850 return Err(anyhow!("supermaven not enabled for this account"))?;
3851 }
3852
3853 let email = session
3854 .email()
3855 .ok_or_else(|| anyhow!("user must have an email"))?;
3856
3857 let supermaven_admin_api = session
3858 .supermaven_client
3859 .as_ref()
3860 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3861
3862 let result = supermaven_admin_api
3863 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3864 .await?;
3865
3866 response.send(proto::GetSupermavenApiKeyResponse {
3867 api_key: result.api_key,
3868 })?;
3869
3870 Ok(())
3871}
3872
3873/// Start receiving chat updates for a channel
3874async fn join_channel_chat(
3875 request: proto::JoinChannelChat,
3876 response: Response<proto::JoinChannelChat>,
3877 session: Session,
3878) -> Result<()> {
3879 let channel_id = ChannelId::from_proto(request.channel_id);
3880
3881 let db = session.db().await;
3882 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3883 .await?;
3884 let messages = db
3885 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3886 .await?;
3887 response.send(proto::JoinChannelChatResponse {
3888 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3889 messages,
3890 })?;
3891 Ok(())
3892}
3893
3894/// Stop receiving chat updates for a channel
3895async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3896 let channel_id = ChannelId::from_proto(request.channel_id);
3897 session
3898 .db()
3899 .await
3900 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3901 .await?;
3902 Ok(())
3903}
3904
3905/// Retrieve the chat history for a channel
3906async fn get_channel_messages(
3907 request: proto::GetChannelMessages,
3908 response: Response<proto::GetChannelMessages>,
3909 session: Session,
3910) -> Result<()> {
3911 let channel_id = ChannelId::from_proto(request.channel_id);
3912 let messages = session
3913 .db()
3914 .await
3915 .get_channel_messages(
3916 channel_id,
3917 session.user_id(),
3918 MESSAGE_COUNT_PER_PAGE,
3919 Some(MessageId::from_proto(request.before_message_id)),
3920 )
3921 .await?;
3922 response.send(proto::GetChannelMessagesResponse {
3923 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3924 messages,
3925 })?;
3926 Ok(())
3927}
3928
3929/// Retrieve specific chat messages
3930async fn get_channel_messages_by_id(
3931 request: proto::GetChannelMessagesById,
3932 response: Response<proto::GetChannelMessagesById>,
3933 session: Session,
3934) -> Result<()> {
3935 let message_ids = request
3936 .message_ids
3937 .iter()
3938 .map(|id| MessageId::from_proto(*id))
3939 .collect::<Vec<_>>();
3940 let messages = session
3941 .db()
3942 .await
3943 .get_channel_messages_by_id(session.user_id(), &message_ids)
3944 .await?;
3945 response.send(proto::GetChannelMessagesResponse {
3946 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3947 messages,
3948 })?;
3949 Ok(())
3950}
3951
3952/// Retrieve the current users notifications
3953async fn get_notifications(
3954 request: proto::GetNotifications,
3955 response: Response<proto::GetNotifications>,
3956 session: Session,
3957) -> Result<()> {
3958 let notifications = session
3959 .db()
3960 .await
3961 .get_notifications(
3962 session.user_id(),
3963 NOTIFICATION_COUNT_PER_PAGE,
3964 request.before_id.map(db::NotificationId::from_proto),
3965 )
3966 .await?;
3967 response.send(proto::GetNotificationsResponse {
3968 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3969 notifications,
3970 })?;
3971 Ok(())
3972}
3973
3974/// Mark notifications as read
3975async fn mark_notification_as_read(
3976 request: proto::MarkNotificationRead,
3977 response: Response<proto::MarkNotificationRead>,
3978 session: Session,
3979) -> Result<()> {
3980 let database = &session.db().await;
3981 let notifications = database
3982 .mark_notification_as_read_by_id(
3983 session.user_id(),
3984 NotificationId::from_proto(request.notification_id),
3985 )
3986 .await?;
3987 send_notifications(
3988 &*session.connection_pool().await,
3989 &session.peer,
3990 notifications,
3991 );
3992 response.send(proto::Ack {})?;
3993 Ok(())
3994}
3995
3996/// Get the current users information
3997async fn get_private_user_info(
3998 _request: proto::GetPrivateUserInfo,
3999 response: Response<proto::GetPrivateUserInfo>,
4000 session: Session,
4001) -> Result<()> {
4002 let db = session.db().await;
4003
4004 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4005 let user = db
4006 .get_user_by_id(session.user_id())
4007 .await?
4008 .ok_or_else(|| anyhow!("user not found"))?;
4009 let flags = db.get_user_flags(session.user_id()).await?;
4010
4011 response.send(proto::GetPrivateUserInfoResponse {
4012 metrics_id,
4013 staff: user.admin,
4014 flags,
4015 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4016 })?;
4017 Ok(())
4018}
4019
4020/// Accept the terms of service (tos) on behalf of the current user
4021async fn accept_terms_of_service(
4022 _request: proto::AcceptTermsOfService,
4023 response: Response<proto::AcceptTermsOfService>,
4024 session: Session,
4025) -> Result<()> {
4026 let db = session.db().await;
4027
4028 let accepted_tos_at = Utc::now();
4029 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4030 .await?;
4031
4032 response.send(proto::AcceptTermsOfServiceResponse {
4033 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4034 })?;
4035 Ok(())
4036}
4037
/// The minimum age an account must have reached before it may use the LLM
/// service. Currently 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4040
4041async fn get_llm_api_token(
4042 _request: proto::GetLlmToken,
4043 response: Response<proto::GetLlmToken>,
4044 session: Session,
4045) -> Result<()> {
4046 let db = session.db().await;
4047
4048 let flags = db.get_user_flags(session.user_id()).await?;
4049 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4050
4051 if !session.is_staff() && !has_language_models_feature_flag {
4052 Err(anyhow!("permission denied"))?
4053 }
4054
4055 let user_id = session.user_id();
4056 let user = db
4057 .get_user_by_id(user_id)
4058 .await?
4059 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4060
4061 if user.accepted_tos_at.is_none() {
4062 Err(anyhow!("terms of service not accepted"))?
4063 }
4064
4065 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4066 let billing_preferences = db.get_billing_preferences(user.id).await?;
4067
4068 let token = LlmTokenClaims::create(
4069 &user,
4070 session.is_staff(),
4071 billing_preferences,
4072 &flags,
4073 has_llm_subscription,
4074 session.current_plan(&db).await?,
4075 session.system_id.clone(),
4076 &session.app_state.config,
4077 )?;
4078 response.send(proto::GetLlmTokenResponse { token })?;
4079 Ok(())
4080}
4081
4082fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4083 let message = match message {
4084 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4085 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4086 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4087 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4088 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4089 code: frame.code.into(),
4090 reason: frame.reason,
4091 })),
4092 // We should never receive a frame while reading the message, according
4093 // to the `tungstenite` maintainers:
4094 //
4095 // > It cannot occur when you read messages from the WebSocket, but it
4096 // > can be used when you want to send the raw frames (e.g. you want to
4097 // > send the frames to the WebSocket without composing the full message first).
4098 // >
4099 // > — https://github.com/snapview/tungstenite-rs/issues/268
4100 TungsteniteMessage::Frame(_) => {
4101 bail!("received an unexpected frame while reading the message")
4102 }
4103 };
4104
4105 Ok(message)
4106}
4107
4108fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4109 match message {
4110 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4111 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4112 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4113 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4114 AxumMessage::Close(frame) => {
4115 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4116 code: frame.code.into(),
4117 reason: frame.reason,
4118 }))
4119 }
4120 }
4121}
4122
4123fn notify_membership_updated(
4124 connection_pool: &mut ConnectionPool,
4125 result: MembershipUpdated,
4126 user_id: UserId,
4127 peer: &Peer,
4128) {
4129 for membership in &result.new_channels.channel_memberships {
4130 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4131 }
4132 for channel_id in &result.removed_channels {
4133 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4134 }
4135
4136 let user_channels_update = proto::UpdateUserChannels {
4137 channel_memberships: result
4138 .new_channels
4139 .channel_memberships
4140 .iter()
4141 .map(|cm| proto::ChannelMembership {
4142 channel_id: cm.channel_id.to_proto(),
4143 role: cm.role.into(),
4144 })
4145 .collect(),
4146 ..Default::default()
4147 };
4148
4149 let mut update = build_channels_update(result.new_channels);
4150 update.delete_channels = result
4151 .removed_channels
4152 .into_iter()
4153 .map(|id| id.to_proto())
4154 .collect();
4155 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4156
4157 for connection_id in connection_pool.user_connection_ids(user_id) {
4158 peer.send(connection_id, user_channels_update.clone())
4159 .trace_err();
4160 peer.send(connection_id, update.clone()).trace_err();
4161 }
4162}
4163
4164fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4165 proto::UpdateUserChannels {
4166 channel_memberships: channels
4167 .channel_memberships
4168 .iter()
4169 .map(|m| proto::ChannelMembership {
4170 channel_id: m.channel_id.to_proto(),
4171 role: m.role.into(),
4172 })
4173 .collect(),
4174 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4175 observed_channel_message_id: channels.observed_channel_messages.clone(),
4176 }
4177}
4178
4179fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4180 let mut update = proto::UpdateChannels::default();
4181
4182 for channel in channels.channels {
4183 update.channels.push(channel.to_proto());
4184 }
4185
4186 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4187 update.latest_channel_message_ids = channels.latest_channel_messages;
4188
4189 for (channel_id, participants) in channels.channel_participants {
4190 update
4191 .channel_participants
4192 .push(proto::ChannelParticipants {
4193 channel_id: channel_id.to_proto(),
4194 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4195 });
4196 }
4197
4198 for channel in channels.invited_channels {
4199 update.channel_invitations.push(channel.to_proto());
4200 }
4201
4202 update
4203}
4204
4205fn build_initial_contacts_update(
4206 contacts: Vec<db::Contact>,
4207 pool: &ConnectionPool,
4208) -> proto::UpdateContacts {
4209 let mut update = proto::UpdateContacts::default();
4210
4211 for contact in contacts {
4212 match contact {
4213 db::Contact::Accepted { user_id, busy } => {
4214 update.contacts.push(contact_for_user(user_id, busy, pool));
4215 }
4216 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4217 db::Contact::Incoming { user_id } => {
4218 update
4219 .incoming_requests
4220 .push(proto::IncomingContactRequest {
4221 requester_id: user_id.to_proto(),
4222 })
4223 }
4224 }
4225 }
4226
4227 update
4228}
4229
4230fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4231 proto::Contact {
4232 user_id: user_id.to_proto(),
4233 online: pool.is_user_online(user_id),
4234 busy,
4235 }
4236}
4237
4238fn room_updated(room: &proto::Room, peer: &Peer) {
4239 broadcast(
4240 None,
4241 room.participants
4242 .iter()
4243 .filter_map(|participant| Some(participant.peer_id?.into())),
4244 |peer_id| {
4245 peer.send(
4246 peer_id,
4247 proto::RoomUpdated {
4248 room: Some(room.clone()),
4249 },
4250 )
4251 },
4252 );
4253}
4254
4255fn channel_updated(
4256 channel: &db::channel::Model,
4257 room: &proto::Room,
4258 peer: &Peer,
4259 pool: &ConnectionPool,
4260) {
4261 let participants = room
4262 .participants
4263 .iter()
4264 .map(|p| p.user_id)
4265 .collect::<Vec<_>>();
4266
4267 broadcast(
4268 None,
4269 pool.channel_connection_ids(channel.root_id())
4270 .filter_map(|(channel_id, role)| {
4271 role.can_see_channel(channel.visibility)
4272 .then_some(channel_id)
4273 }),
4274 |peer_id| {
4275 peer.send(
4276 peer_id,
4277 proto::UpdateChannels {
4278 channel_participants: vec![proto::ChannelParticipants {
4279 channel_id: channel.id.to_proto(),
4280 participant_user_ids: participants.clone(),
4281 }],
4282 ..Default::default()
4283 },
4284 )
4285 },
4286 );
4287}
4288
4289async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4290 let db = session.db().await;
4291
4292 let contacts = db.get_contacts(user_id).await?;
4293 let busy = db.is_user_busy(user_id).await?;
4294
4295 let pool = session.connection_pool().await;
4296 let updated_contact = contact_for_user(user_id, busy, &pool);
4297 for contact in contacts {
4298 if let db::Contact::Accepted {
4299 user_id: contact_user_id,
4300 ..
4301 } = contact
4302 {
4303 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4304 session
4305 .peer
4306 .send(
4307 contact_conn_id,
4308 proto::UpdateContacts {
4309 contacts: vec![updated_contact.clone()],
4310 remove_contacts: Default::default(),
4311 incoming_requests: Default::default(),
4312 remove_incoming_requests: Default::default(),
4313 outgoing_requests: Default::default(),
4314 remove_outgoing_requests: Default::default(),
4315 },
4316 )
4317 .trace_err();
4318 }
4319 }
4320 }
4321 Ok(())
4322}
4323
4324async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4325 let mut contacts_to_update = HashSet::default();
4326
4327 let room_id;
4328 let canceled_calls_to_user_ids;
4329 let livekit_room;
4330 let delete_livekit_room;
4331 let room;
4332 let channel;
4333
4334 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4335 contacts_to_update.insert(session.user_id());
4336
4337 for project in left_room.left_projects.values() {
4338 project_left(project, session);
4339 }
4340
4341 room_id = RoomId::from_proto(left_room.room.id);
4342 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4343 livekit_room = mem::take(&mut left_room.room.livekit_room);
4344 delete_livekit_room = left_room.deleted;
4345 room = mem::take(&mut left_room.room);
4346 channel = mem::take(&mut left_room.channel);
4347
4348 room_updated(&room, &session.peer);
4349 } else {
4350 return Ok(());
4351 }
4352
4353 if let Some(channel) = channel {
4354 channel_updated(
4355 &channel,
4356 &room,
4357 &session.peer,
4358 &*session.connection_pool().await,
4359 );
4360 }
4361
4362 {
4363 let pool = session.connection_pool().await;
4364 for canceled_user_id in canceled_calls_to_user_ids {
4365 for connection_id in pool.user_connection_ids(canceled_user_id) {
4366 session
4367 .peer
4368 .send(
4369 connection_id,
4370 proto::CallCanceled {
4371 room_id: room_id.to_proto(),
4372 },
4373 )
4374 .trace_err();
4375 }
4376 contacts_to_update.insert(canceled_user_id);
4377 }
4378 }
4379
4380 for contact_user_id in contacts_to_update {
4381 update_user_contacts(contact_user_id, session).await?;
4382 }
4383
4384 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4385 live_kit
4386 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4387 .await
4388 .trace_err();
4389
4390 if delete_livekit_room {
4391 live_kit.delete_room(livekit_room).await.trace_err();
4392 }
4393 }
4394
4395 Ok(())
4396}
4397
4398async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4399 let left_channel_buffers = session
4400 .db()
4401 .await
4402 .leave_channel_buffers(session.connection_id)
4403 .await?;
4404
4405 for left_buffer in left_channel_buffers {
4406 channel_buffer_updated(
4407 session.connection_id,
4408 left_buffer.connections,
4409 &proto::UpdateChannelBufferCollaborators {
4410 channel_id: left_buffer.channel_id.to_proto(),
4411 collaborators: left_buffer.collaborators,
4412 },
4413 &session.peer,
4414 );
4415 }
4416
4417 Ok(())
4418}
4419
4420fn project_left(project: &db::LeftProject, session: &Session) {
4421 for connection_id in &project.connection_ids {
4422 if project.should_unshare {
4423 session
4424 .peer
4425 .send(
4426 *connection_id,
4427 proto::UnshareProject {
4428 project_id: project.id.to_proto(),
4429 },
4430 )
4431 .trace_err();
4432 } else {
4433 session
4434 .peer
4435 .send(
4436 *connection_id,
4437 proto::RemoveProjectCollaborator {
4438 project_id: project.id.to_proto(),
4439 peer_id: Some(session.connection_id.into()),
4440 },
4441 )
4442 .trace_err();
4443 }
4444 }
4445}
4446
/// Extension trait for turning a fallible value into an `Option`.
pub trait ResultExt {
    /// The success type produced by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self` and returns the success value, if any, as an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4452
4453impl<T, E> ResultExt for Result<T, E>
4454where
4455 E: std::fmt::Debug,
4456{
4457 type Ok = T;
4458
4459 #[track_caller]
4460 fn trace_err(self) -> Option<T> {
4461 match self {
4462 Ok(value) => Some(value),
4463 Err(error) => {
4464 tracing::error!("{:?}", error);
4465 None
4466 }
4467 }
4468 }
4469}