1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnect timeout: 30 seconds. Consumers of this constant are outside this part of the file.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. NOTE(review): presumably used by the channel-message and
// notification handlers further down the file; confirm against their definitions.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope plus the `Session` of the connection it arrived on, and
/// returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed (by clone) to every message handler.
///
/// `Server::handle_connection` builds one `Session` per connection and clones it for each
/// incoming message.
#[derive(Clone)]
struct Session {
    /// Who this connection is authenticated as.
    principal: Principal,
    /// Id assigned to this connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use `Session::db` to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; use `Session::connection_pool` to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id passed to `handle_connection`, if any.
    system_id: Option<String>,
    /// Executor the connection runs on (currently not read through this field).
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex so test builds can replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and performs message I/O.
    peer: Arc<Peer>,
    /// Connections known to this server, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (config, database, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; connection loops subscribe to it in `handle_connection`.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the `ConnectionPool`.
///
/// It dereferences to the pool and, in test builds, checks the pool's invariants when dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker keeps the guard itself `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server: the peer plus the (locked) connection pool.
///
/// The pool lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the `ConnectionPool` behind the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
/// Creates a server with the given id and registers every RPC handler on it.
///
/// Each `add_request_handler` / `add_message_handler` call stores a handler keyed by the
/// `TypeId` of its message type; registering the same message type twice panics (see
/// `add_handler`). The finished server is returned inside an `Arc`.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sending half is kept; the initial receiver is dropped right here and
        // new receivers are created with `subscribe()` in `handle_connection`.
        teardown: watch::channel(false).0,
    };

    server
        // Ping, rooms and calls.
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        // Project sharing and project/worktree state updates.
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests forwarded through `forward_read_only_project_request` or
        // `forward_mutating_project_request` (both defined elsewhere in this file).
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        // Buffer creation/updates and host-to-guest broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        // Channels, channel buffers, channel chat and notifications.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        // Following.
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        // Account, LLM token and miscellaneous acknowledgements.
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        // Contexts, git and breakpoint requests forwarded to the project host.
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
        .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
        .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
        // NOTE(review): `GitReset` and `GitCheckoutFiles` are registered through the
        // read-only forwarder although their names suggest they modify the working tree.
        // Confirm against `forward_read_only_project_request` that this is intended.
        .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
        .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
        .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // The last handlers need configuration, so they are closures capturing `app_state`.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                // Clone per call so the returned future owns its own `Arc`.
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler(get_cached_embeddings)
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            }
        });

    Arc::new(server)
}
441
/// Spawns the background task that cleans up resources left behind by stale servers.
///
/// The task is detached and this method returns `Ok(())` right after spawning it. Once
/// `CLEANUP_TIMEOUT` has elapsed the task:
/// 1. fetches the stale room and channel ids from the database,
/// 2. clears stale channel-buffer collaborators and notifies the remaining connections,
/// 3. clears stale room participants, notifies room/channel members, cancels pending calls,
///    refreshes affected contacts, and deletes LiveKit rooms that ended up empty,
/// 4. deletes the stale server records.
///
/// Failures inside the task are passed to `trace_err()` instead of being propagated.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // The sleep future is created here but only awaited inside the spawned task.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: drop stale collaborators and tell every remaining
                // connection about the new collaborator list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Defaults below make the rest of the iteration a no-op when the room
                    // could not be refreshed.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            // The pool lock is a temporary that lives only for this call.
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the next `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push the refreshed contact entry of each affected user to all of
                    // that user's accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            // Locked only after both database calls have completed.
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Only rooms that have no participants left are deleted from LiveKit.
                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when the stale resource ids could not be retrieved.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
587
588 pub fn teardown(&self) {
589 self.peer.teardown();
590 self.connection_pool.lock().reset();
591 let _ = self.teardown.send(true);
592 }
593
/// Test-only: tears the server down and brings it back up under a new id.
///
/// After `teardown`, the stored id and the peer are switched to `id` and the teardown flag
/// is cleared so that new connections are accepted again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // Use `send_replace` rather than `send`: once teardown has ended the connection loops
    // there may be no receivers left, and `send` would then leave the flag stuck at `true`.
    let _ = self.teardown.send_replace(false);
}
601
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
606
607 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
608 where
609 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
610 Fut: 'static + Send + Future<Output = Result<()>>,
611 M: EnvelopedMessage,
612 {
613 let prev_handler = self.handlers.insert(
614 TypeId::of::<M>(),
615 Box::new(move |envelope, session| {
616 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
617 let received_at = envelope.received_at;
618 tracing::info!("message received");
619 let start_time = Instant::now();
620 let future = (handler)(*envelope, session);
621 async move {
622 let result = future.await;
623 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
624 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
625 let queue_duration_ms = total_duration_ms - processing_duration_ms;
626 let payload_type = M::NAME;
627
628 match result {
629 Err(error) => {
630 tracing::error!(
631 ?error,
632 total_duration_ms,
633 processing_duration_ms,
634 queue_duration_ms,
635 payload_type,
636 "error handling message"
637 )
638 }
639 Ok(()) => tracing::info!(
640 total_duration_ms,
641 processing_duration_ms,
642 queue_duration_ms,
643 "finished handling message"
644 ),
645 }
646 }
647 .boxed()
648 }),
649 );
650 if prev_handler.is_some() {
651 panic!("registered a handler for the same message twice");
652 }
653 self
654 }
655
656 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
657 where
658 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
659 Fut: 'static + Send + Future<Output = Result<()>>,
660 M: EnvelopedMessage,
661 {
662 self.add_handler(move |envelope, session| handler(envelope.payload, session));
663 self
664 }
665
666 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
667 where
668 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
669 Fut: Send + Future<Output = Result<()>>,
670 M: RequestMessage,
671 {
672 let handler = Arc::new(handler);
673 self.add_handler(move |envelope, session| {
674 let receipt = envelope.receipt();
675 let handler = handler.clone();
676 async move {
677 let peer = session.peer.clone();
678 let responded = Arc::new(AtomicBool::default());
679 let response = Response {
680 peer: peer.clone(),
681 responded: responded.clone(),
682 receipt,
683 };
684 match (handler)(envelope.payload, response, session).await {
685 Ok(()) => {
686 if responded.load(std::sync::atomic::Ordering::SeqCst) {
687 Ok(())
688 } else {
689 Err(anyhow!("handler did not send a response"))?
690 }
691 }
692 Err(error) => {
693 let proto_err = match &error {
694 Error::Internal(err) => err.to_proto(),
695 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
696 };
697 peer.respond_with_error(receipt, proto_err)?;
698 Err(error)
699 }
700 }
701 }
702 })
703 }
704
/// Returns a future that drives a single client connection until it ends.
///
/// The future registers the connection with the peer, builds the connection's `Session`,
/// sends the initial client update, and then loops: it pumps connection I/O, dispatches
/// incoming messages to the registered handlers, and drives foreground handlers. When the
/// loop ends because I/O finished or the message stream closed, `connection_lost` is
/// called to sign the user out.
///
/// The future returns early, without calling `connection_lost`, when:
/// - the server is already tearing down, or teardown is signalled while running;
/// - the HTTP client cannot be created;
/// - the initial client update fails.
///
/// `send_connection_id`, when provided, is handed to `send_initial_client_update`, which
/// sends the new `ConnectionId` through it.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    // Fields are declared empty up front so they can be recorded once known.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    // Subscribe before the future runs so the teardown flag can be observed inside it.
    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only constructed when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
            supermaven_admin_api_key.to_string(),
            http_client.clone(),
        )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            http_client,
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 256 handlers may be in flight for this connection at once.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            // A permit is acquired *before* the next message is read, so no further
            // messages are pulled off the connection while all permits are in use.
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls in declaration order: teardown, connection I/O,
            // running foreground handlers, and only then the next incoming message.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            // Leave the span before the future is instrumented with it below.
                            drop(span_enter);

                            // The permit is released only once the handler has finished.
                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                // Background messages run detached on the executor.
                                executor.spawn_detached(handle_message);
                            } else {
                                // Foreground messages are driven by this loop.
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Foreground handlers that have not finished are dropped before signing out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
853
854 async fn send_initial_client_update(
855 &self,
856 connection_id: ConnectionId,
857 principal: &Principal,
858 zed_version: ZedVersion,
859 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
860 session: &Session,
861 ) -> Result<()> {
862 self.peer.send(
863 connection_id,
864 proto::Hello {
865 peer_id: Some(connection_id.into()),
866 },
867 )?;
868 tracing::info!("sent hello message");
869 if let Some(send_connection_id) = send_connection_id.take() {
870 let _ = send_connection_id.send(connection_id);
871 }
872
873 match principal {
874 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
875 if !user.connected_once {
876 self.peer.send(connection_id, proto::ShowContacts {})?;
877 self.app_state
878 .db
879 .set_user_connected_once(user.id, true)
880 .await?;
881 }
882
883 update_user_plan(user.id, session).await?;
884
885 let contacts = self.app_state.db.get_contacts(user.id).await?;
886
887 {
888 let mut pool = self.connection_pool.lock();
889 pool.add_connection(connection_id, user.id, user.admin, zed_version);
890 self.peer.send(
891 connection_id,
892 build_initial_contacts_update(contacts, &pool),
893 )?;
894 }
895
896 if should_auto_subscribe_to_channels(zed_version) {
897 subscribe_user_to_channels(user.id, session).await?;
898 }
899
900 if let Some(incoming_call) =
901 self.app_state.db.incoming_call_for_user(user.id).await?
902 {
903 self.peer.send(connection_id, incoming_call)?;
904 }
905
906 update_user_contacts(user.id, session).await?;
907 }
908 }
909
910 Ok(())
911 }
912
913 pub async fn invite_code_redeemed(
914 self: &Arc<Self>,
915 inviter_id: UserId,
916 invitee_id: UserId,
917 ) -> Result<()> {
918 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
919 if let Some(code) = &user.invite_code {
920 let pool = self.connection_pool.lock();
921 let invitee_contact = contact_for_user(invitee_id, false, &pool);
922 for connection_id in pool.user_connection_ids(inviter_id) {
923 self.peer.send(
924 connection_id,
925 proto::UpdateContacts {
926 contacts: vec![invitee_contact.clone()],
927 ..Default::default()
928 },
929 )?;
930 self.peer.send(
931 connection_id,
932 proto::UpdateInviteInfo {
933 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
934 count: user.invite_count as u32,
935 },
936 )?;
937 }
938 }
939 }
940 Ok(())
941 }
942
943 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
944 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
945 if let Some(invite_code) = &user.invite_code {
946 let pool = self.connection_pool.lock();
947 for connection_id in pool.user_connection_ids(user_id) {
948 self.peer.send(
949 connection_id,
950 proto::UpdateInviteInfo {
951 url: format!(
952 "{}{}",
953 self.app_state.config.invite_link_prefix, invite_code
954 ),
955 count: user.invite_count as u32,
956 },
957 )?;
958 }
959 }
960 }
961 Ok(())
962 }
963
964 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
965 let pool = self.connection_pool.lock();
966 for connection_id in pool.user_connection_ids(user_id) {
967 self.peer
968 .send(connection_id, proto::RefreshLlmToken {})
969 .trace_err();
970 }
971 }
972
973 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
974 ServerSnapshot {
975 connection_pool: ConnectionPoolGuard {
976 guard: self.connection_pool.lock(),
977 _not_send: PhantomData,
978 },
979 peer: &self.peer,
980 }
981 }
982}
983
/// Lets a `ConnectionPoolGuard` be used wherever a shared reference to the
/// guarded [`ConnectionPool`] is expected.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        // Borrow the pool through the wrapped lock guard.
        &self.guard
    }
}
991
/// Gives mutable access to the guarded [`ConnectionPool`] through the wrapped
/// lock guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
997
/// When the guard is dropped in test builds, the pool's invariants are
/// checked; in non-test builds the body is empty.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1004
1005fn broadcast<F>(
1006 sender_id: Option<ConnectionId>,
1007 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1008 mut f: F,
1009) where
1010 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1011{
1012 for receiver_id in receiver_ids {
1013 if Some(receiver_id) != sender_id {
1014 if let Err(error) = f(receiver_id) {
1015 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1016 }
1017 }
1018 }
1019}
1020
1021pub struct ProtocolVersion(u32);
1022
1023impl Header for ProtocolVersion {
1024 fn name() -> &'static HeaderName {
1025 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1026 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1027 }
1028
1029 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1030 where
1031 Self: Sized,
1032 I: Iterator<Item = &'i axum::http::HeaderValue>,
1033 {
1034 let version = values
1035 .next()
1036 .ok_or_else(axum::headers::Error::invalid)?
1037 .to_str()
1038 .map_err(|_| axum::headers::Error::invalid())?
1039 .parse()
1040 .map_err(|_| axum::headers::Error::invalid())?;
1041 Ok(Self(version))
1042 }
1043
1044 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1045 values.extend([self.0.to_string().parse().unwrap()]);
1046 }
1047}
1048
1049pub struct AppVersionHeader(SemanticVersion);
1050impl Header for AppVersionHeader {
1051 fn name() -> &'static HeaderName {
1052 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1053 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1054 }
1055
1056 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1057 where
1058 Self: Sized,
1059 I: Iterator<Item = &'i axum::http::HeaderValue>,
1060 {
1061 let version = values
1062 .next()
1063 .ok_or_else(axum::headers::Error::invalid)?
1064 .to_str()
1065 .map_err(|_| axum::headers::Error::invalid())?
1066 .parse()
1067 .map_err(|_| axum::headers::Error::invalid())?;
1068 Ok(Self(version))
1069 }
1070
1071 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1072 values.extend([self.0.to_string().parse().unwrap()]);
1073 }
1074}
1075
1076pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1077 Router::new()
1078 .route("/rpc", get(handle_websocket_request))
1079 .layer(
1080 ServiceBuilder::new()
1081 .layer(Extension(server.app_state.clone()))
1082 .layer(middleware::from_fn(auth::validate_header)),
1083 )
1084 .route("/metrics", get(handle_metrics))
1085 .layer(Extension(server))
1086}
1087
1088pub async fn handle_websocket_request(
1089 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1090 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1091 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1092 Extension(server): Extension<Arc<Server>>,
1093 Extension(principal): Extension<Principal>,
1094 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1095 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1096 ws: WebSocketUpgrade,
1097) -> axum::response::Response {
1098 if protocol_version != rpc::PROTOCOL_VERSION {
1099 return (
1100 StatusCode::UPGRADE_REQUIRED,
1101 "client must be upgraded".to_string(),
1102 )
1103 .into_response();
1104 }
1105
1106 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1107 return (
1108 StatusCode::UPGRADE_REQUIRED,
1109 "no version header found".to_string(),
1110 )
1111 .into_response();
1112 };
1113
1114 if !version.can_collaborate() {
1115 return (
1116 StatusCode::UPGRADE_REQUIRED,
1117 "client must be upgraded".to_string(),
1118 )
1119 .into_response();
1120 }
1121
1122 let socket_address = socket_address.to_string();
1123 ws.on_upgrade(move |socket| {
1124 let socket = socket
1125 .map_ok(to_tungstenite_message)
1126 .err_into()
1127 .with(|message| async move { to_axum_message(message) });
1128 let connection = Connection::new(Box::pin(socket));
1129 async move {
1130 server
1131 .handle_connection(
1132 connection,
1133 socket_address,
1134 principal,
1135 version,
1136 country_code_header.map(|header| header.to_string()),
1137 system_id_header.map(|header| header.to_string()),
1138 None,
1139 Executor::Production,
1140 )
1141 .await;
1142 }
1143 })
1144}
1145
1146pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1147 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148 let connections_metric = CONNECTIONS_METRIC
1149 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1150
1151 let connections = server
1152 .connection_pool
1153 .lock()
1154 .connections()
1155 .filter(|connection| !connection.admin)
1156 .count();
1157 connections_metric.set(connections as _);
1158
1159 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1160 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1161 register_int_gauge!(
1162 "shared_projects",
1163 "number of open projects with one or more guests"
1164 )
1165 .unwrap()
1166 });
1167
1168 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1169 shared_projects_metric.set(shared_projects as _);
1170
1171 let encoder = prometheus::TextEncoder::new();
1172 let metric_families = prometheus::gather();
1173 let encoded_metrics = encoder
1174 .encode_to_string(&metric_families)
1175 .map_err(|err| anyhow!("{}", err))?;
1176 Ok(encoded_metrics)
1177}
1178
/// Cleans up after a connection has been lost.
///
/// Immediately disconnects the peer, removes the connection from the pool and
/// records the loss in the database. It then waits for whichever happens
/// first:
///
/// * `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel
///   buffers, any pending call is declined if the user has no other connection
///   online, and the user's contacts are updated.
/// * `teardown` changes: nothing further is done.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool or if the final
/// contact update fails. Other failures are passed to `trace_err` instead of
/// being propagated.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here does not abort the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order, so the timeout branch is
    // preferred when both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when the user has no connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1225
/// Acknowledges a ping from a client, used to keep the connection alive.
///
/// Replies with an empty `Ack`; the request payload and the session are not
/// inspected.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    response.send(proto::Ack {})?;
    Ok(())
}
1231
1232/// Creates a new room for calling (outside of channels)
1233async fn create_room(
1234 _request: proto::CreateRoom,
1235 response: Response<proto::CreateRoom>,
1236 session: Session,
1237) -> Result<()> {
1238 let livekit_room = nanoid::nanoid!(30);
1239
1240 let live_kit_connection_info = util::maybe!(async {
1241 let live_kit = session.app_state.livekit_client.as_ref();
1242 let live_kit = live_kit?;
1243 let user_id = session.user_id().to_string();
1244
1245 let token = live_kit
1246 .room_token(&livekit_room, &user_id.to_string())
1247 .trace_err()?;
1248
1249 Some(proto::LiveKitConnectionInfo {
1250 server_url: live_kit.url().into(),
1251 token,
1252 can_publish: true,
1253 })
1254 })
1255 .await;
1256
1257 let room = session
1258 .db()
1259 .await
1260 .create_room(session.user_id(), session.connection_id, &livekit_room)
1261 .await?;
1262
1263 response.send(proto::CreateRoomResponse {
1264 room: Some(room.clone()),
1265 live_kit_connection_info,
1266 })?;
1267
1268 update_user_contacts(session.user_id(), &session).await?;
1269 Ok(())
1270}
1271
1272/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1273async fn join_room(
1274 request: proto::JoinRoom,
1275 response: Response<proto::JoinRoom>,
1276 session: Session,
1277) -> Result<()> {
1278 let room_id = RoomId::from_proto(request.id);
1279
1280 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1281
1282 if let Some(channel_id) = channel_id {
1283 return join_channel_internal(channel_id, Box::new(response), session).await;
1284 }
1285
1286 let joined_room = {
1287 let room = session
1288 .db()
1289 .await
1290 .join_room(room_id, session.user_id(), session.connection_id)
1291 .await?;
1292 room_updated(&room.room, &session.peer);
1293 room.into_inner()
1294 };
1295
1296 for connection_id in session
1297 .connection_pool()
1298 .await
1299 .user_connection_ids(session.user_id())
1300 {
1301 session
1302 .peer
1303 .send(
1304 connection_id,
1305 proto::CallCanceled {
1306 room_id: room_id.to_proto(),
1307 },
1308 )
1309 .trace_err();
1310 }
1311
1312 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1313 {
1314 live_kit
1315 .room_token(
1316 &joined_room.room.livekit_room,
1317 &session.user_id().to_string(),
1318 )
1319 .trace_err()
1320 .map(|token| proto::LiveKitConnectionInfo {
1321 server_url: live_kit.url().into(),
1322 token,
1323 can_publish: true,
1324 })
1325 } else {
1326 None
1327 };
1328
1329 response.send(proto::JoinRoomResponse {
1330 room: Some(joined_room.room),
1331 channel_id: None,
1332 live_kit_connection_info,
1333 })?;
1334
1335 update_user_contacts(session.user_id(), &session).await?;
1336 Ok(())
1337}
1338
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to re-associate the session with the room. The
/// response lists the projects this connection had shared ("reshared") and the
/// projects it had joined ("rejoined"). Afterwards:
///
/// * room participants are notified of the updated room,
/// * collaborators of each reshared project are told that the host's peer id
///   changed and are forwarded the project's worktree metadata,
/// * the state of each rejoined project is streamed back to this connection,
/// * if the room belongs to a channel, a channel update is emitted,
/// * the user's contacts are updated.
///
/// # Errors
///
/// Fails if the database call fails, the response cannot be sent, streaming
/// rejoined project state fails, or the contact update fails.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        // `rejoined_room` is confined to this block; only the room and channel
        // extracted via `into_inner` outlive it.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's old peer id has been
            // replaced by this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree metadata to everyone except this
            // connection, attributed to this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1430
/// Brings a reconnected guest and its collaborators back in sync for every
/// project in `rejoined_projects`.
///
/// First, each collaborator of each project is told that the rejoining peer's
/// id changed from the project's `old_connection_id` to the session's current
/// connection id (send failures go to `trace_err`). Then, for each project,
/// the session's own connection is sent the worktree entries (split into
/// chunks), diagnostic summaries, worktree settings files, and one
/// "disk-based diagnostics updated" notice per language server.
///
/// The `worktrees` of every project are taken (left empty) in the process.
///
/// # Errors
///
/// Returns the first error encountered while streaming state to the session's
/// connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // Take ownership of the worktrees so their contents can be moved into
        // the outgoing messages.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update.clone())?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Send one "disk-based diagnostics updated" notice per language server.
        for language_server in &project.language_servers {
            session.peer.send(
                session.connection_id,
                proto::UpdateLanguageServer {
                    project_id: project.id.to_proto(),
                    language_server_id: language_server.id,
                    variant: Some(
                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                            proto::LspDiskBasedDiagnosticsUpdated {},
                        ),
                    ),
                },
            )?;
        }
    }
    Ok(())
}
1513
/// Leave room disconnects the session's connection from its room and
/// acknowledges the request.
///
/// # Errors
///
/// Fails if leaving the room fails or the acknowledgement cannot be sent.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: Session,
) -> Result<()> {
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1524
1525/// Updates the permissions of someone else in the room.
1526async fn set_room_participant_role(
1527 request: proto::SetRoomParticipantRole,
1528 response: Response<proto::SetRoomParticipantRole>,
1529 session: Session,
1530) -> Result<()> {
1531 let user_id = UserId::from_proto(request.user_id);
1532 let role = ChannelRole::from(request.role());
1533
1534 let (livekit_room, can_publish) = {
1535 let room = session
1536 .db()
1537 .await
1538 .set_room_participant_role(
1539 session.user_id(),
1540 RoomId::from_proto(request.room_id),
1541 user_id,
1542 role,
1543 )
1544 .await?;
1545
1546 let livekit_room = room.livekit_room.clone();
1547 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1548 room_updated(&room, &session.peer);
1549 (livekit_room, can_publish)
1550 };
1551
1552 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1553 live_kit
1554 .update_participant(
1555 livekit_room.clone(),
1556 request.user_id.to_string(),
1557 livekit_api::proto::ParticipantPermission {
1558 can_subscribe: true,
1559 can_publish,
1560 can_publish_data: can_publish,
1561 hidden: false,
1562 recorder: false,
1563 },
1564 )
1565 .await
1566 .trace_err();
1567 }
1568
1569 response.send(proto::Ack {})?;
1570 Ok(())
1571}
1572
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database, room participants are notified, and the incoming-call message is
/// sent as a request to every connection of the callee. The request is
/// acknowledged as soon as any one of those connections responds successfully.
/// If none does, the call is marked as failed in the database.
///
/// # Errors
///
/// Fails if the callee is not a contact, if a database operation or contact
/// update fails, or with "failed to ring user" when no connection of the
/// callee answered the ring.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the borrowed pair, leaving a default behind.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // The first successful ring is enough to acknowledge the call.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: record the failed call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1641
1642/// Cancel an outgoing call.
1643async fn cancel_call(
1644 request: proto::CancelCall,
1645 response: Response<proto::CancelCall>,
1646 session: Session,
1647) -> Result<()> {
1648 let called_user_id = UserId::from_proto(request.called_user_id);
1649 let room_id = RoomId::from_proto(request.room_id);
1650 {
1651 let room = session
1652 .db()
1653 .await
1654 .cancel_call(room_id, session.connection_id, called_user_id)
1655 .await?;
1656 room_updated(&room, &session.peer);
1657 }
1658
1659 for connection_id in session
1660 .connection_pool()
1661 .await
1662 .user_connection_ids(called_user_id)
1663 {
1664 session
1665 .peer
1666 .send(
1667 connection_id,
1668 proto::CallCanceled {
1669 room_id: room_id.to_proto(),
1670 },
1671 )
1672 .trace_err();
1673 }
1674 response.send(proto::Ack {})?;
1675
1676 update_user_contacts(called_user_id, &session).await?;
1677 Ok(())
1678}
1679
1680/// Decline an incoming call.
1681async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1682 let room_id = RoomId::from_proto(message.room_id);
1683 {
1684 let room = session
1685 .db()
1686 .await
1687 .decline_call(Some(room_id), session.user_id())
1688 .await?
1689 .ok_or_else(|| anyhow!("failed to decline call"))?;
1690 room_updated(&room, &session.peer);
1691 }
1692
1693 for connection_id in session
1694 .connection_pool()
1695 .await
1696 .user_connection_ids(session.user_id())
1697 {
1698 session
1699 .peer
1700 .send(
1701 connection_id,
1702 proto::CallCanceled {
1703 room_id: room_id.to_proto(),
1704 },
1705 )
1706 .trace_err();
1707 }
1708 update_user_contacts(session.user_id(), &session).await?;
1709 Ok(())
1710}
1711
1712/// Updates other participants in the room with your current location.
1713async fn update_participant_location(
1714 request: proto::UpdateParticipantLocation,
1715 response: Response<proto::UpdateParticipantLocation>,
1716 session: Session,
1717) -> Result<()> {
1718 let room_id = RoomId::from_proto(request.room_id);
1719 let location = request
1720 .location
1721 .ok_or_else(|| anyhow!("invalid location"))?;
1722
1723 let db = session.db().await;
1724 let room = db
1725 .update_room_participant_location(room_id, session.connection_id, location)
1726 .await?;
1727
1728 room_updated(&room, &session.peer);
1729 response.send(proto::Ack {})?;
1730 Ok(())
1731}
1732
/// Share a project into the room.
///
/// Registers the project (with the worktrees from the request) in the
/// database for the given room, replies with the new project id, and then
/// notifies the room's participants.
///
/// # Errors
///
/// Fails if the database call fails or the response cannot be sent.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    // The requester learns the project id before the room update goes out.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1756
/// Unshare a project from the room.
///
/// Delegates to [`unshare_project_internal`] using the session's own
/// connection id.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);
    unshare_project_internal(project_id, session.connection_id, &session).await
}
1762
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// The database reports whether the project should be deleted, the updated
/// room (if any), and the guests' connection ids. All guests other than
/// `connection_id` are sent an `UnshareProject` message, room participants are
/// notified when a room was returned, and finally the project is deleted if
/// the database asked for it.
///
/// # Errors
///
/// Fails if the unshare or the delete database call fails. Failures to notify
/// individual guests are logged by `broadcast` and not propagated.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // `room_guard` is confined to this block, so it is dropped before the
        // project is deleted below.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1800
/// Join someone else's shared project.
///
/// Registers the session as a collaborator of the project in the database and
/// then hands the resulting project and replica id to
/// [`join_project_internal`], which answers the request and streams the
/// project state.
///
/// # Errors
///
/// Fails if the database call fails or if `join_project_internal` fails.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project state.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1819
/// Abstraction over "something that can deliver a `JoinProjectResponse`", so
/// that [`join_project_internal`] does not depend on a concrete response type.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`.
        // NOTE(review): presumably spelled this way so it cannot resolve to
        // this trait method and recurse; confirm.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1828
/// Completes a project join for the session's connection.
///
/// Steps, in order:
/// 1. Every existing collaborator (excluding this connection) is sent an
///    `AddProjectCollaborator` message describing the new guest
///    (best-effort).
/// 2. The request is answered with the project's worktree metadata, the
///    guest's replica id, the other collaborators, the language servers and
///    the guest's role.
/// 3. For every worktree, the entries (split into chunks), diagnostic
///    summaries and settings files are streamed to this connection. The
///    project's `worktrees` are taken (left empty) in the process.
/// 4. One "disk-based diagnostics updated" notice is sent per language server.
///
/// # Errors
///
/// Fails if the response cannot be sent or if streaming any piece of project
/// state to this connection fails.
///
/// # Panics
///
/// Panics if a collaborator's proto representation has no `peer_id`.
fn join_project_internal(
    response: impl JoinProjectInternalResponse,
    session: Session,
    project: &mut Project,
    replica_id: &ReplicaId,
) -> Result<()> {
    // Everyone already in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                // NOTE(review): assumes `to_proto` always fills in `peer_id`;
                // this unwrap panics otherwise. Confirm.
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        // NOTE(review): other call sites use `project_id.to_proto()`;
        // presumably equivalent to this cast. Confirm.
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // NOTE(review): uses `worktree.id` here but the map key `worktree_id`
        // above; presumably always equal. Confirm.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    // NOTE(review): `notify_rejoined_projects` converts the
                    // kind with `.into()` rather than `as i32`; presumably the
                    // same value. Confirm.
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Send one "disk-based diagnostics updated" notice per language server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1946
/// Leave someone else's shared project.
///
/// Removes the session's connection from the project in the database, runs
/// `project_left` for the project, and notifies the room's participants when
/// the database returned a room.
///
/// # Errors
///
/// Fails if the database call fails.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1966
1967/// Updates other participants with changes to the project
1968async fn update_project(
1969 request: proto::UpdateProject,
1970 response: Response<proto::UpdateProject>,
1971 session: Session,
1972) -> Result<()> {
1973 let project_id = ProjectId::from_proto(request.project_id);
1974 let (room, guest_connection_ids) = &*session
1975 .db()
1976 .await
1977 .update_project(project_id, session.connection_id, &request.worktrees)
1978 .await?;
1979 broadcast(
1980 Some(session.connection_id),
1981 guest_connection_ids.iter().copied(),
1982 |connection_id| {
1983 session
1984 .peer
1985 .forward_send(session.connection_id, connection_id, request.clone())
1986 },
1987 );
1988 if let Some(room) = room {
1989 room_updated(room, &session.peer);
1990 }
1991 response.send(proto::Ack {})?;
1992
1993 Ok(())
1994}
1995
1996/// Updates other participants with changes to the worktree
1997async fn update_worktree(
1998 request: proto::UpdateWorktree,
1999 response: Response<proto::UpdateWorktree>,
2000 session: Session,
2001) -> Result<()> {
2002 let guest_connection_ids = session
2003 .db()
2004 .await
2005 .update_worktree(&request, session.connection_id)
2006 .await?;
2007
2008 broadcast(
2009 Some(session.connection_id),
2010 guest_connection_ids.iter().copied(),
2011 |connection_id| {
2012 session
2013 .peer
2014 .forward_send(session.connection_id, connection_id, request.clone())
2015 },
2016 );
2017 response.send(proto::Ack {})?;
2018 Ok(())
2019}
2020
2021/// Updates other participants with changes to the diagnostics
2022async fn update_diagnostic_summary(
2023 message: proto::UpdateDiagnosticSummary,
2024 session: Session,
2025) -> Result<()> {
2026 let guest_connection_ids = session
2027 .db()
2028 .await
2029 .update_diagnostic_summary(&message, session.connection_id)
2030 .await?;
2031
2032 broadcast(
2033 Some(session.connection_id),
2034 guest_connection_ids.iter().copied(),
2035 |connection_id| {
2036 session
2037 .peer
2038 .forward_send(session.connection_id, connection_id, message.clone())
2039 },
2040 );
2041
2042 Ok(())
2043}
2044
2045/// Updates other participants with changes to the worktree settings
2046async fn update_worktree_settings(
2047 message: proto::UpdateWorktreeSettings,
2048 session: Session,
2049) -> Result<()> {
2050 let guest_connection_ids = session
2051 .db()
2052 .await
2053 .update_worktree_settings(&message, session.connection_id)
2054 .await?;
2055
2056 broadcast(
2057 Some(session.connection_id),
2058 guest_connection_ids.iter().copied(),
2059 |connection_id| {
2060 session
2061 .peer
2062 .forward_send(session.connection_id, connection_id, message.clone())
2063 },
2064 );
2065
2066 Ok(())
2067}
2068
2069/// Notify other participants that a language server has started.
2070async fn start_language_server(
2071 request: proto::StartLanguageServer,
2072 session: Session,
2073) -> Result<()> {
2074 let guest_connection_ids = session
2075 .db()
2076 .await
2077 .start_language_server(&request, session.connection_id)
2078 .await?;
2079
2080 broadcast(
2081 Some(session.connection_id),
2082 guest_connection_ids.iter().copied(),
2083 |connection_id| {
2084 session
2085 .peer
2086 .forward_send(session.connection_id, connection_id, request.clone())
2087 },
2088 );
2089 Ok(())
2090}
2091
2092/// Notify other participants that a language server has changed.
2093async fn update_language_server(
2094 request: proto::UpdateLanguageServer,
2095 session: Session,
2096) -> Result<()> {
2097 let project_id = ProjectId::from_proto(request.project_id);
2098 let project_connection_ids = session
2099 .db()
2100 .await
2101 .project_connection_ids(project_id, session.connection_id, true)
2102 .await?;
2103 broadcast(
2104 Some(session.connection_id),
2105 project_connection_ids.iter().copied(),
2106 |connection_id| {
2107 session
2108 .peer
2109 .forward_send(session.connection_id, connection_id, request.clone())
2110 },
2111 );
2112 Ok(())
2113}
2114
2115/// forward a project request to the host. These requests should be read only
2116/// as guests are allowed to send them.
2117async fn forward_read_only_project_request<T>(
2118 request: T,
2119 response: Response<T>,
2120 session: Session,
2121) -> Result<()>
2122where
2123 T: EntityMessage + RequestMessage,
2124{
2125 let project_id = ProjectId::from_proto(request.remote_entity_id());
2126 let host_connection_id = session
2127 .db()
2128 .await
2129 .host_for_read_only_project_request(project_id, session.connection_id)
2130 .await?;
2131 let payload = session
2132 .peer
2133 .forward_request(session.connection_id, host_connection_id, request)
2134 .await?;
2135 response.send(payload)?;
2136 Ok(())
2137}
2138
2139async fn forward_find_search_candidates_request(
2140 request: proto::FindSearchCandidates,
2141 response: Response<proto::FindSearchCandidates>,
2142 session: Session,
2143) -> Result<()> {
2144 let project_id = ProjectId::from_proto(request.remote_entity_id());
2145 let host_connection_id = session
2146 .db()
2147 .await
2148 .host_for_read_only_project_request(project_id, session.connection_id)
2149 .await?;
2150 let payload = session
2151 .peer
2152 .forward_request(session.connection_id, host_connection_id, request)
2153 .await?;
2154 response.send(payload)?;
2155 Ok(())
2156}
2157
2158/// forward a project request to the host. These requests are disallowed
2159/// for guests.
2160async fn forward_mutating_project_request<T>(
2161 request: T,
2162 response: Response<T>,
2163 session: Session,
2164) -> Result<()>
2165where
2166 T: EntityMessage + RequestMessage,
2167{
2168 let project_id = ProjectId::from_proto(request.remote_entity_id());
2169
2170 let host_connection_id = session
2171 .db()
2172 .await
2173 .host_for_mutating_project_request(project_id, session.connection_id)
2174 .await?;
2175 let payload = session
2176 .peer
2177 .forward_request(session.connection_id, host_connection_id, request)
2178 .await?;
2179 response.send(payload)?;
2180 Ok(())
2181}
2182
2183/// Notify other participants that a new buffer has been created
2184async fn create_buffer_for_peer(
2185 request: proto::CreateBufferForPeer,
2186 session: Session,
2187) -> Result<()> {
2188 session
2189 .db()
2190 .await
2191 .check_user_is_project_host(
2192 ProjectId::from_proto(request.project_id),
2193 session.connection_id,
2194 )
2195 .await?;
2196 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2197 session
2198 .peer
2199 .forward_send(session.connection_id, peer_id.into(), request)?;
2200 Ok(())
2201}
2202
2203/// Notify other participants that a buffer has been updated. This is
2204/// allowed for guests as long as the update is limited to selections.
2205async fn update_buffer(
2206 request: proto::UpdateBuffer,
2207 response: Response<proto::UpdateBuffer>,
2208 session: Session,
2209) -> Result<()> {
2210 let project_id = ProjectId::from_proto(request.project_id);
2211 let mut capability = Capability::ReadOnly;
2212
2213 for op in request.operations.iter() {
2214 match op.variant {
2215 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2216 Some(_) => capability = Capability::ReadWrite,
2217 }
2218 }
2219
2220 let host = {
2221 let guard = session
2222 .db()
2223 .await
2224 .connections_for_buffer_update(project_id, session.connection_id, capability)
2225 .await?;
2226
2227 let (host, guests) = &*guard;
2228
2229 broadcast(
2230 Some(session.connection_id),
2231 guests.clone(),
2232 |connection_id| {
2233 session
2234 .peer
2235 .forward_send(session.connection_id, connection_id, request.clone())
2236 },
2237 );
2238
2239 *host
2240 };
2241
2242 if host != session.connection_id {
2243 session
2244 .peer
2245 .forward_request(session.connection_id, host, request.clone())
2246 .await?;
2247 }
2248
2249 response.send(proto::Ack {})?;
2250 Ok(())
2251}
2252
2253async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2254 let project_id = ProjectId::from_proto(message.project_id);
2255
2256 let operation = message.operation.as_ref().context("invalid operation")?;
2257 let capability = match operation.variant.as_ref() {
2258 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2259 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2260 match buffer_op.variant {
2261 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2262 Capability::ReadOnly
2263 }
2264 _ => Capability::ReadWrite,
2265 }
2266 } else {
2267 Capability::ReadWrite
2268 }
2269 }
2270 Some(_) => Capability::ReadWrite,
2271 None => Capability::ReadOnly,
2272 };
2273
2274 let guard = session
2275 .db()
2276 .await
2277 .connections_for_buffer_update(project_id, session.connection_id, capability)
2278 .await?;
2279
2280 let (host, guests) = &*guard;
2281
2282 broadcast(
2283 Some(session.connection_id),
2284 guests.iter().chain([host]).copied(),
2285 |connection_id| {
2286 session
2287 .peer
2288 .forward_send(session.connection_id, connection_id, message.clone())
2289 },
2290 );
2291
2292 Ok(())
2293}
2294
2295/// Notify other participants that a project has been updated.
2296async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2297 request: T,
2298 session: Session,
2299) -> Result<()> {
2300 let project_id = ProjectId::from_proto(request.remote_entity_id());
2301 let project_connection_ids = session
2302 .db()
2303 .await
2304 .project_connection_ids(project_id, session.connection_id, false)
2305 .await?;
2306
2307 broadcast(
2308 Some(session.connection_id),
2309 project_connection_ids.iter().copied(),
2310 |connection_id| {
2311 session
2312 .peer
2313 .forward_send(session.connection_id, connection_id, request.clone())
2314 },
2315 );
2316 Ok(())
2317}
2318
2319/// Start following another user in a call.
2320async fn follow(
2321 request: proto::Follow,
2322 response: Response<proto::Follow>,
2323 session: Session,
2324) -> Result<()> {
2325 let room_id = RoomId::from_proto(request.room_id);
2326 let project_id = request.project_id.map(ProjectId::from_proto);
2327 let leader_id = request
2328 .leader_id
2329 .ok_or_else(|| anyhow!("invalid leader id"))?
2330 .into();
2331 let follower_id = session.connection_id;
2332
2333 session
2334 .db()
2335 .await
2336 .check_room_participants(room_id, leader_id, session.connection_id)
2337 .await?;
2338
2339 let response_payload = session
2340 .peer
2341 .forward_request(session.connection_id, leader_id, request)
2342 .await?;
2343 response.send(response_payload)?;
2344
2345 if let Some(project_id) = project_id {
2346 let room = session
2347 .db()
2348 .await
2349 .follow(room_id, project_id, leader_id, follower_id)
2350 .await?;
2351 room_updated(&room, &session.peer);
2352 }
2353
2354 Ok(())
2355}
2356
2357/// Stop following another user in a call.
2358async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2359 let room_id = RoomId::from_proto(request.room_id);
2360 let project_id = request.project_id.map(ProjectId::from_proto);
2361 let leader_id = request
2362 .leader_id
2363 .ok_or_else(|| anyhow!("invalid leader id"))?
2364 .into();
2365 let follower_id = session.connection_id;
2366
2367 session
2368 .db()
2369 .await
2370 .check_room_participants(room_id, leader_id, session.connection_id)
2371 .await?;
2372
2373 session
2374 .peer
2375 .forward_send(session.connection_id, leader_id, request)?;
2376
2377 if let Some(project_id) = project_id {
2378 let room = session
2379 .db()
2380 .await
2381 .unfollow(room_id, project_id, leader_id, follower_id)
2382 .await?;
2383 room_updated(&room, &session.peer);
2384 }
2385
2386 Ok(())
2387}
2388
2389/// Notify everyone following you of your current location.
2390async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2391 let room_id = RoomId::from_proto(request.room_id);
2392 let database = session.db.lock().await;
2393
2394 let connection_ids = if let Some(project_id) = request.project_id {
2395 let project_id = ProjectId::from_proto(project_id);
2396 database
2397 .project_connection_ids(project_id, session.connection_id, true)
2398 .await?
2399 } else {
2400 database
2401 .room_connection_ids(room_id, session.connection_id)
2402 .await?
2403 };
2404
2405 // For now, don't send view update messages back to that view's current leader.
2406 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2407 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2408 _ => None,
2409 });
2410
2411 for connection_id in connection_ids.iter().cloned() {
2412 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2413 session
2414 .peer
2415 .forward_send(session.connection_id, connection_id, request.clone())?;
2416 }
2417 }
2418 Ok(())
2419}
2420
2421/// Get public data about users.
2422async fn get_users(
2423 request: proto::GetUsers,
2424 response: Response<proto::GetUsers>,
2425 session: Session,
2426) -> Result<()> {
2427 let user_ids = request
2428 .user_ids
2429 .into_iter()
2430 .map(UserId::from_proto)
2431 .collect();
2432 let users = session
2433 .db()
2434 .await
2435 .get_users_by_ids(user_ids)
2436 .await?
2437 .into_iter()
2438 .map(|user| proto::User {
2439 id: user.id.to_proto(),
2440 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2441 github_login: user.github_login,
2442 email: user.email_address,
2443 name: user.name,
2444 })
2445 .collect();
2446 response.send(proto::UsersResponse { users })?;
2447 Ok(())
2448}
2449
2450/// Search for users (to invite) buy Github login
2451async fn fuzzy_search_users(
2452 request: proto::FuzzySearchUsers,
2453 response: Response<proto::FuzzySearchUsers>,
2454 session: Session,
2455) -> Result<()> {
2456 let query = request.query;
2457 let users = match query.len() {
2458 0 => vec![],
2459 1 | 2 => session
2460 .db()
2461 .await
2462 .get_user_by_github_login(&query)
2463 .await?
2464 .into_iter()
2465 .collect(),
2466 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2467 };
2468 let users = users
2469 .into_iter()
2470 .filter(|user| user.id != session.user_id())
2471 .map(|user| proto::User {
2472 id: user.id.to_proto(),
2473 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2474 github_login: user.github_login,
2475 name: user.name,
2476 email: user.email_address,
2477 })
2478 .collect();
2479 response.send(proto::UsersResponse { users })?;
2480 Ok(())
2481}
2482
2483/// Send a contact request to another user.
2484async fn request_contact(
2485 request: proto::RequestContact,
2486 response: Response<proto::RequestContact>,
2487 session: Session,
2488) -> Result<()> {
2489 let requester_id = session.user_id();
2490 let responder_id = UserId::from_proto(request.responder_id);
2491 if requester_id == responder_id {
2492 return Err(anyhow!("cannot add yourself as a contact"))?;
2493 }
2494
2495 let notifications = session
2496 .db()
2497 .await
2498 .send_contact_request(requester_id, responder_id)
2499 .await?;
2500
2501 // Update outgoing contact requests of requester
2502 let mut update = proto::UpdateContacts::default();
2503 update.outgoing_requests.push(responder_id.to_proto());
2504 for connection_id in session
2505 .connection_pool()
2506 .await
2507 .user_connection_ids(requester_id)
2508 {
2509 session.peer.send(connection_id, update.clone())?;
2510 }
2511
2512 // Update incoming contact requests of responder
2513 let mut update = proto::UpdateContacts::default();
2514 update
2515 .incoming_requests
2516 .push(proto::IncomingContactRequest {
2517 requester_id: requester_id.to_proto(),
2518 });
2519 let connection_pool = session.connection_pool().await;
2520 for connection_id in connection_pool.user_connection_ids(responder_id) {
2521 session.peer.send(connection_id, update.clone())?;
2522 }
2523
2524 send_notifications(&connection_pool, &session.peer, notifications);
2525
2526 response.send(proto::Ack {})?;
2527 Ok(())
2528}
2529
2530/// Accept or decline a contact request
2531async fn respond_to_contact_request(
2532 request: proto::RespondToContactRequest,
2533 response: Response<proto::RespondToContactRequest>,
2534 session: Session,
2535) -> Result<()> {
2536 let responder_id = session.user_id();
2537 let requester_id = UserId::from_proto(request.requester_id);
2538 let db = session.db().await;
2539 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2540 db.dismiss_contact_notification(responder_id, requester_id)
2541 .await?;
2542 } else {
2543 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2544
2545 let notifications = db
2546 .respond_to_contact_request(responder_id, requester_id, accept)
2547 .await?;
2548 let requester_busy = db.is_user_busy(requester_id).await?;
2549 let responder_busy = db.is_user_busy(responder_id).await?;
2550
2551 let pool = session.connection_pool().await;
2552 // Update responder with new contact
2553 let mut update = proto::UpdateContacts::default();
2554 if accept {
2555 update
2556 .contacts
2557 .push(contact_for_user(requester_id, requester_busy, &pool));
2558 }
2559 update
2560 .remove_incoming_requests
2561 .push(requester_id.to_proto());
2562 for connection_id in pool.user_connection_ids(responder_id) {
2563 session.peer.send(connection_id, update.clone())?;
2564 }
2565
2566 // Update requester with new contact
2567 let mut update = proto::UpdateContacts::default();
2568 if accept {
2569 update
2570 .contacts
2571 .push(contact_for_user(responder_id, responder_busy, &pool));
2572 }
2573 update
2574 .remove_outgoing_requests
2575 .push(responder_id.to_proto());
2576
2577 for connection_id in pool.user_connection_ids(requester_id) {
2578 session.peer.send(connection_id, update.clone())?;
2579 }
2580
2581 send_notifications(&pool, &session.peer, notifications);
2582 }
2583
2584 response.send(proto::Ack {})?;
2585 Ok(())
2586}
2587
2588/// Remove a contact.
2589async fn remove_contact(
2590 request: proto::RemoveContact,
2591 response: Response<proto::RemoveContact>,
2592 session: Session,
2593) -> Result<()> {
2594 let requester_id = session.user_id();
2595 let responder_id = UserId::from_proto(request.user_id);
2596 let db = session.db().await;
2597 let (contact_accepted, deleted_notification_id) =
2598 db.remove_contact(requester_id, responder_id).await?;
2599
2600 let pool = session.connection_pool().await;
2601 // Update outgoing contact requests of requester
2602 let mut update = proto::UpdateContacts::default();
2603 if contact_accepted {
2604 update.remove_contacts.push(responder_id.to_proto());
2605 } else {
2606 update
2607 .remove_outgoing_requests
2608 .push(responder_id.to_proto());
2609 }
2610 for connection_id in pool.user_connection_ids(requester_id) {
2611 session.peer.send(connection_id, update.clone())?;
2612 }
2613
2614 // Update incoming contact requests of responder
2615 let mut update = proto::UpdateContacts::default();
2616 if contact_accepted {
2617 update.remove_contacts.push(requester_id.to_proto());
2618 } else {
2619 update
2620 .remove_incoming_requests
2621 .push(requester_id.to_proto());
2622 }
2623 for connection_id in pool.user_connection_ids(responder_id) {
2624 session.peer.send(connection_id, update.clone())?;
2625 if let Some(notification_id) = deleted_notification_id {
2626 session.peer.send(
2627 connection_id,
2628 proto::DeleteNotification {
2629 notification_id: notification_id.to_proto(),
2630 },
2631 )?;
2632 }
2633 }
2634
2635 response.send(proto::Ack {})?;
2636 Ok(())
2637}
2638
2639fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2640 version.0.minor() < 139
2641}
2642
2643async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2644 let plan = session.current_plan(&session.db().await).await?;
2645
2646 session
2647 .peer
2648 .send(
2649 session.connection_id,
2650 proto::UpdateUserPlan { plan: plan.into() },
2651 )
2652 .trace_err();
2653
2654 Ok(())
2655}
2656
2657async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2658 subscribe_user_to_channels(session.user_id(), &session).await?;
2659 Ok(())
2660}
2661
2662async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2663 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2664 let mut pool = session.connection_pool().await;
2665 for membership in &channels_for_user.channel_memberships {
2666 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2667 }
2668 session.peer.send(
2669 session.connection_id,
2670 build_update_user_channels(&channels_for_user),
2671 )?;
2672 session.peer.send(
2673 session.connection_id,
2674 build_channels_update(channels_for_user),
2675 )?;
2676 Ok(())
2677}
2678
2679/// Creates a new channel.
2680async fn create_channel(
2681 request: proto::CreateChannel,
2682 response: Response<proto::CreateChannel>,
2683 session: Session,
2684) -> Result<()> {
2685 let db = session.db().await;
2686
2687 let parent_id = request.parent_id.map(ChannelId::from_proto);
2688 let (channel, membership) = db
2689 .create_channel(&request.name, parent_id, session.user_id())
2690 .await?;
2691
2692 let root_id = channel.root_id();
2693 let channel = Channel::from_model(channel);
2694
2695 response.send(proto::CreateChannelResponse {
2696 channel: Some(channel.to_proto()),
2697 parent_id: request.parent_id,
2698 })?;
2699
2700 let mut connection_pool = session.connection_pool().await;
2701 if let Some(membership) = membership {
2702 connection_pool.subscribe_to_channel(
2703 membership.user_id,
2704 membership.channel_id,
2705 membership.role,
2706 );
2707 let update = proto::UpdateUserChannels {
2708 channel_memberships: vec![proto::ChannelMembership {
2709 channel_id: membership.channel_id.to_proto(),
2710 role: membership.role.into(),
2711 }],
2712 ..Default::default()
2713 };
2714 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2715 session.peer.send(connection_id, update.clone())?;
2716 }
2717 }
2718
2719 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2720 if !role.can_see_channel(channel.visibility) {
2721 continue;
2722 }
2723
2724 let update = proto::UpdateChannels {
2725 channels: vec![channel.to_proto()],
2726 ..Default::default()
2727 };
2728 session.peer.send(connection_id, update.clone())?;
2729 }
2730
2731 Ok(())
2732}
2733
2734/// Delete a channel
2735async fn delete_channel(
2736 request: proto::DeleteChannel,
2737 response: Response<proto::DeleteChannel>,
2738 session: Session,
2739) -> Result<()> {
2740 let db = session.db().await;
2741
2742 let channel_id = request.channel_id;
2743 let (root_channel, removed_channels) = db
2744 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2745 .await?;
2746 response.send(proto::Ack {})?;
2747
2748 // Notify members of removed channels
2749 let mut update = proto::UpdateChannels::default();
2750 update
2751 .delete_channels
2752 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2753
2754 let connection_pool = session.connection_pool().await;
2755 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2756 session.peer.send(connection_id, update.clone())?;
2757 }
2758
2759 Ok(())
2760}
2761
2762/// Invite someone to join a channel.
2763async fn invite_channel_member(
2764 request: proto::InviteChannelMember,
2765 response: Response<proto::InviteChannelMember>,
2766 session: Session,
2767) -> Result<()> {
2768 let db = session.db().await;
2769 let channel_id = ChannelId::from_proto(request.channel_id);
2770 let invitee_id = UserId::from_proto(request.user_id);
2771 let InviteMemberResult {
2772 channel,
2773 notifications,
2774 } = db
2775 .invite_channel_member(
2776 channel_id,
2777 invitee_id,
2778 session.user_id(),
2779 request.role().into(),
2780 )
2781 .await?;
2782
2783 let update = proto::UpdateChannels {
2784 channel_invitations: vec![channel.to_proto()],
2785 ..Default::default()
2786 };
2787
2788 let connection_pool = session.connection_pool().await;
2789 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2790 session.peer.send(connection_id, update.clone())?;
2791 }
2792
2793 send_notifications(&connection_pool, &session.peer, notifications);
2794
2795 response.send(proto::Ack {})?;
2796 Ok(())
2797}
2798
2799/// remove someone from a channel
2800async fn remove_channel_member(
2801 request: proto::RemoveChannelMember,
2802 response: Response<proto::RemoveChannelMember>,
2803 session: Session,
2804) -> Result<()> {
2805 let db = session.db().await;
2806 let channel_id = ChannelId::from_proto(request.channel_id);
2807 let member_id = UserId::from_proto(request.user_id);
2808
2809 let RemoveChannelMemberResult {
2810 membership_update,
2811 notification_id,
2812 } = db
2813 .remove_channel_member(channel_id, member_id, session.user_id())
2814 .await?;
2815
2816 let mut connection_pool = session.connection_pool().await;
2817 notify_membership_updated(
2818 &mut connection_pool,
2819 membership_update,
2820 member_id,
2821 &session.peer,
2822 );
2823 for connection_id in connection_pool.user_connection_ids(member_id) {
2824 if let Some(notification_id) = notification_id {
2825 session
2826 .peer
2827 .send(
2828 connection_id,
2829 proto::DeleteNotification {
2830 notification_id: notification_id.to_proto(),
2831 },
2832 )
2833 .trace_err();
2834 }
2835 }
2836
2837 response.send(proto::Ack {})?;
2838 Ok(())
2839}
2840
2841/// Toggle the channel between public and private.
2842/// Care is taken to maintain the invariant that public channels only descend from public channels,
2843/// (though members-only channels can appear at any point in the hierarchy).
2844async fn set_channel_visibility(
2845 request: proto::SetChannelVisibility,
2846 response: Response<proto::SetChannelVisibility>,
2847 session: Session,
2848) -> Result<()> {
2849 let db = session.db().await;
2850 let channel_id = ChannelId::from_proto(request.channel_id);
2851 let visibility = request.visibility().into();
2852
2853 let channel_model = db
2854 .set_channel_visibility(channel_id, visibility, session.user_id())
2855 .await?;
2856 let root_id = channel_model.root_id();
2857 let channel = Channel::from_model(channel_model);
2858
2859 let mut connection_pool = session.connection_pool().await;
2860 for (user_id, role) in connection_pool
2861 .channel_user_ids(root_id)
2862 .collect::<Vec<_>>()
2863 .into_iter()
2864 {
2865 let update = if role.can_see_channel(channel.visibility) {
2866 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2867 proto::UpdateChannels {
2868 channels: vec![channel.to_proto()],
2869 ..Default::default()
2870 }
2871 } else {
2872 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2873 proto::UpdateChannels {
2874 delete_channels: vec![channel.id.to_proto()],
2875 ..Default::default()
2876 }
2877 };
2878
2879 for connection_id in connection_pool.user_connection_ids(user_id) {
2880 session.peer.send(connection_id, update.clone())?;
2881 }
2882 }
2883
2884 response.send(proto::Ack {})?;
2885 Ok(())
2886}
2887
2888/// Alter the role for a user in the channel.
2889async fn set_channel_member_role(
2890 request: proto::SetChannelMemberRole,
2891 response: Response<proto::SetChannelMemberRole>,
2892 session: Session,
2893) -> Result<()> {
2894 let db = session.db().await;
2895 let channel_id = ChannelId::from_proto(request.channel_id);
2896 let member_id = UserId::from_proto(request.user_id);
2897 let result = db
2898 .set_channel_member_role(
2899 channel_id,
2900 session.user_id(),
2901 member_id,
2902 request.role().into(),
2903 )
2904 .await?;
2905
2906 match result {
2907 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2908 let mut connection_pool = session.connection_pool().await;
2909 notify_membership_updated(
2910 &mut connection_pool,
2911 membership_update,
2912 member_id,
2913 &session.peer,
2914 )
2915 }
2916 db::SetMemberRoleResult::InviteUpdated(channel) => {
2917 let update = proto::UpdateChannels {
2918 channel_invitations: vec![channel.to_proto()],
2919 ..Default::default()
2920 };
2921
2922 for connection_id in session
2923 .connection_pool()
2924 .await
2925 .user_connection_ids(member_id)
2926 {
2927 session.peer.send(connection_id, update.clone())?;
2928 }
2929 }
2930 }
2931
2932 response.send(proto::Ack {})?;
2933 Ok(())
2934}
2935
2936/// Change the name of a channel
2937async fn rename_channel(
2938 request: proto::RenameChannel,
2939 response: Response<proto::RenameChannel>,
2940 session: Session,
2941) -> Result<()> {
2942 let db = session.db().await;
2943 let channel_id = ChannelId::from_proto(request.channel_id);
2944 let channel_model = db
2945 .rename_channel(channel_id, session.user_id(), &request.name)
2946 .await?;
2947 let root_id = channel_model.root_id();
2948 let channel = Channel::from_model(channel_model);
2949
2950 response.send(proto::RenameChannelResponse {
2951 channel: Some(channel.to_proto()),
2952 })?;
2953
2954 let connection_pool = session.connection_pool().await;
2955 let update = proto::UpdateChannels {
2956 channels: vec![channel.to_proto()],
2957 ..Default::default()
2958 };
2959 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2960 if role.can_see_channel(channel.visibility) {
2961 session.peer.send(connection_id, update.clone())?;
2962 }
2963 }
2964
2965 Ok(())
2966}
2967
2968/// Move a channel to a new parent.
2969async fn move_channel(
2970 request: proto::MoveChannel,
2971 response: Response<proto::MoveChannel>,
2972 session: Session,
2973) -> Result<()> {
2974 let channel_id = ChannelId::from_proto(request.channel_id);
2975 let to = ChannelId::from_proto(request.to);
2976
2977 let (root_id, channels) = session
2978 .db()
2979 .await
2980 .move_channel(channel_id, to, session.user_id())
2981 .await?;
2982
2983 let connection_pool = session.connection_pool().await;
2984 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2985 let channels = channels
2986 .iter()
2987 .filter_map(|channel| {
2988 if role.can_see_channel(channel.visibility) {
2989 Some(channel.to_proto())
2990 } else {
2991 None
2992 }
2993 })
2994 .collect::<Vec<_>>();
2995 if channels.is_empty() {
2996 continue;
2997 }
2998
2999 let update = proto::UpdateChannels {
3000 channels,
3001 ..Default::default()
3002 };
3003
3004 session.peer.send(connection_id, update.clone())?;
3005 }
3006
3007 response.send(Ack {})?;
3008 Ok(())
3009}
3010
3011/// Get the list of channel members
3012async fn get_channel_members(
3013 request: proto::GetChannelMembers,
3014 response: Response<proto::GetChannelMembers>,
3015 session: Session,
3016) -> Result<()> {
3017 let db = session.db().await;
3018 let channel_id = ChannelId::from_proto(request.channel_id);
3019 let limit = if request.limit == 0 {
3020 u16::MAX as u64
3021 } else {
3022 request.limit
3023 };
3024 let (members, users) = db
3025 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3026 .await?;
3027 response.send(proto::GetChannelMembersResponse { members, users })?;
3028 Ok(())
3029}
3030
3031/// Accept or decline a channel invitation.
3032async fn respond_to_channel_invite(
3033 request: proto::RespondToChannelInvite,
3034 response: Response<proto::RespondToChannelInvite>,
3035 session: Session,
3036) -> Result<()> {
3037 let db = session.db().await;
3038 let channel_id = ChannelId::from_proto(request.channel_id);
3039 let RespondToChannelInvite {
3040 membership_update,
3041 notifications,
3042 } = db
3043 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3044 .await?;
3045
3046 let mut connection_pool = session.connection_pool().await;
3047 if let Some(membership_update) = membership_update {
3048 notify_membership_updated(
3049 &mut connection_pool,
3050 membership_update,
3051 session.user_id(),
3052 &session.peer,
3053 );
3054 } else {
3055 let update = proto::UpdateChannels {
3056 remove_channel_invitations: vec![channel_id.to_proto()],
3057 ..Default::default()
3058 };
3059
3060 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3061 session.peer.send(connection_id, update.clone())?;
3062 }
3063 };
3064
3065 send_notifications(&connection_pool, &session.peer, notifications);
3066
3067 response.send(proto::Ack {})?;
3068
3069 Ok(())
3070}
3071
3072/// Join the channels' room
3073async fn join_channel(
3074 request: proto::JoinChannel,
3075 response: Response<proto::JoinChannel>,
3076 session: Session,
3077) -> Result<()> {
3078 let channel_id = ChannelId::from_proto(request.channel_id);
3079 join_channel_internal(channel_id, Box::new(response), session).await
3080}
3081
3082trait JoinChannelInternalResponse {
3083 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3084}
3085impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3086 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3087 Response::<proto::JoinChannel>::send(self, result)
3088 }
3089}
3090impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3091 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3092 Response::<proto::JoinRoom>::send(self, result)
3093 }
3094}
3095
/// Joins the room of the given channel on behalf of the session's user and
/// replies through `response` with a `JoinRoomResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for the user, leave the
///    room for that connection first.
/// 2. Join the channel in the database.
/// 3. Build the optional LiveKit connection info and send the response.
/// 4. Announce the membership change (if any) and the room update.
/// 5. Announce the channel update and refresh the user's contacts.
///
/// Returns an error if any database call or send fails, or if the joined room
/// carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The block scopes the database handle and the connection-pool guard taken
    // inside it; only `joined_room` escapes.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The handle is released before `leave_room_for_session` runs and
            // re-acquired afterwards.
            // NOTE(review): presumably because that helper takes the database
            // handle itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when `trace_err`
        // yields `None` for the token result (the `?` exits the closure early).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every
                    // other role gets a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is acquired again here, after the guard taken inside
    // the block above has been dropped.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3191
3192/// Start editing the channel notes
3193async fn join_channel_buffer(
3194 request: proto::JoinChannelBuffer,
3195 response: Response<proto::JoinChannelBuffer>,
3196 session: Session,
3197) -> Result<()> {
3198 let db = session.db().await;
3199 let channel_id = ChannelId::from_proto(request.channel_id);
3200
3201 let open_response = db
3202 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3203 .await?;
3204
3205 let collaborators = open_response.collaborators.clone();
3206 response.send(open_response)?;
3207
3208 let update = UpdateChannelBufferCollaborators {
3209 channel_id: channel_id.to_proto(),
3210 collaborators: collaborators.clone(),
3211 };
3212 channel_buffer_updated(
3213 session.connection_id,
3214 collaborators
3215 .iter()
3216 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3217 &update,
3218 &session.peer,
3219 );
3220
3221 Ok(())
3222}
3223
3224/// Edit the channel notes
3225async fn update_channel_buffer(
3226 request: proto::UpdateChannelBuffer,
3227 session: Session,
3228) -> Result<()> {
3229 let db = session.db().await;
3230 let channel_id = ChannelId::from_proto(request.channel_id);
3231
3232 let (collaborators, epoch, version) = db
3233 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3234 .await?;
3235
3236 channel_buffer_updated(
3237 session.connection_id,
3238 collaborators.clone(),
3239 &proto::UpdateChannelBuffer {
3240 channel_id: channel_id.to_proto(),
3241 operations: request.operations,
3242 },
3243 &session.peer,
3244 );
3245
3246 let pool = &*session.connection_pool().await;
3247
3248 let non_collaborators =
3249 pool.channel_connection_ids(channel_id)
3250 .filter_map(|(connection_id, _)| {
3251 if collaborators.contains(&connection_id) {
3252 None
3253 } else {
3254 Some(connection_id)
3255 }
3256 });
3257
3258 broadcast(None, non_collaborators, |peer_id| {
3259 session.peer.send(
3260 peer_id,
3261 proto::UpdateChannels {
3262 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3263 channel_id: channel_id.to_proto(),
3264 epoch: epoch as u64,
3265 version: version.clone(),
3266 }],
3267 ..Default::default()
3268 },
3269 )
3270 });
3271
3272 Ok(())
3273}
3274
3275/// Rejoin the channel notes after a connection blip
3276async fn rejoin_channel_buffers(
3277 request: proto::RejoinChannelBuffers,
3278 response: Response<proto::RejoinChannelBuffers>,
3279 session: Session,
3280) -> Result<()> {
3281 let db = session.db().await;
3282 let buffers = db
3283 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3284 .await?;
3285
3286 for rejoined_buffer in &buffers {
3287 let collaborators_to_notify = rejoined_buffer
3288 .buffer
3289 .collaborators
3290 .iter()
3291 .filter_map(|c| Some(c.peer_id?.into()));
3292 channel_buffer_updated(
3293 session.connection_id,
3294 collaborators_to_notify,
3295 &proto::UpdateChannelBufferCollaborators {
3296 channel_id: rejoined_buffer.buffer.channel_id,
3297 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3298 },
3299 &session.peer,
3300 );
3301 }
3302
3303 response.send(proto::RejoinChannelBuffersResponse {
3304 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3305 })?;
3306
3307 Ok(())
3308}
3309
3310/// Stop editing the channel notes
3311async fn leave_channel_buffer(
3312 request: proto::LeaveChannelBuffer,
3313 response: Response<proto::LeaveChannelBuffer>,
3314 session: Session,
3315) -> Result<()> {
3316 let db = session.db().await;
3317 let channel_id = ChannelId::from_proto(request.channel_id);
3318
3319 let left_buffer = db
3320 .leave_channel_buffer(channel_id, session.connection_id)
3321 .await?;
3322
3323 response.send(Ack {})?;
3324
3325 channel_buffer_updated(
3326 session.connection_id,
3327 left_buffer.connections,
3328 &proto::UpdateChannelBufferCollaborators {
3329 channel_id: channel_id.to_proto(),
3330 collaborators: left_buffer.collaborators,
3331 },
3332 &session.peer,
3333 );
3334
3335 Ok(())
3336}
3337
3338fn channel_buffer_updated<T: EnvelopedMessage>(
3339 sender_id: ConnectionId,
3340 collaborators: impl IntoIterator<Item = ConnectionId>,
3341 message: &T,
3342 peer: &Peer,
3343) {
3344 broadcast(Some(sender_id), collaborators, |peer_id| {
3345 peer.send(peer_id, message.clone())
3346 });
3347}
3348
3349fn send_notifications(
3350 connection_pool: &ConnectionPool,
3351 peer: &Peer,
3352 notifications: db::NotificationBatch,
3353) {
3354 for (user_id, notification) in notifications {
3355 for connection_id in connection_pool.user_connection_ids(user_id) {
3356 if let Err(error) = peer.send(
3357 connection_id,
3358 proto::AddNotification {
3359 notification: Some(notification.clone()),
3360 },
3361 ) {
3362 tracing::error!(
3363 "failed to send notification to {:?} {}",
3364 connection_id,
3365 error
3366 );
3367 }
3368 }
3369 }
3370}
3371
3372/// Send a message to the channel
3373async fn send_channel_message(
3374 request: proto::SendChannelMessage,
3375 response: Response<proto::SendChannelMessage>,
3376 session: Session,
3377) -> Result<()> {
3378 // Validate the message body.
3379 let body = request.body.trim().to_string();
3380 if body.len() > MAX_MESSAGE_LEN {
3381 return Err(anyhow!("message is too long"))?;
3382 }
3383 if body.is_empty() {
3384 return Err(anyhow!("message can't be blank"))?;
3385 }
3386
3387 // TODO: adjust mentions if body is trimmed
3388
3389 let timestamp = OffsetDateTime::now_utc();
3390 let nonce = request
3391 .nonce
3392 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3393
3394 let channel_id = ChannelId::from_proto(request.channel_id);
3395 let CreatedChannelMessage {
3396 message_id,
3397 participant_connection_ids,
3398 notifications,
3399 } = session
3400 .db()
3401 .await
3402 .create_channel_message(
3403 channel_id,
3404 session.user_id(),
3405 &body,
3406 &request.mentions,
3407 timestamp,
3408 nonce.clone().into(),
3409 request.reply_to_message_id.map(MessageId::from_proto),
3410 )
3411 .await?;
3412
3413 let message = proto::ChannelMessage {
3414 sender_id: session.user_id().to_proto(),
3415 id: message_id.to_proto(),
3416 body,
3417 mentions: request.mentions,
3418 timestamp: timestamp.unix_timestamp() as u64,
3419 nonce: Some(nonce),
3420 reply_to_message_id: request.reply_to_message_id,
3421 edited_at: None,
3422 };
3423 broadcast(
3424 Some(session.connection_id),
3425 participant_connection_ids.clone(),
3426 |connection| {
3427 session.peer.send(
3428 connection,
3429 proto::ChannelMessageSent {
3430 channel_id: channel_id.to_proto(),
3431 message: Some(message.clone()),
3432 },
3433 )
3434 },
3435 );
3436 response.send(proto::SendChannelMessageResponse {
3437 message: Some(message),
3438 })?;
3439
3440 let pool = &*session.connection_pool().await;
3441 let non_participants =
3442 pool.channel_connection_ids(channel_id)
3443 .filter_map(|(connection_id, _)| {
3444 if participant_connection_ids.contains(&connection_id) {
3445 None
3446 } else {
3447 Some(connection_id)
3448 }
3449 });
3450 broadcast(None, non_participants, |peer_id| {
3451 session.peer.send(
3452 peer_id,
3453 proto::UpdateChannels {
3454 latest_channel_message_ids: vec![proto::ChannelMessageId {
3455 channel_id: channel_id.to_proto(),
3456 message_id: message_id.to_proto(),
3457 }],
3458 ..Default::default()
3459 },
3460 )
3461 });
3462 send_notifications(pool, &session.peer, notifications);
3463
3464 Ok(())
3465}
3466
3467/// Delete a channel message
3468async fn remove_channel_message(
3469 request: proto::RemoveChannelMessage,
3470 response: Response<proto::RemoveChannelMessage>,
3471 session: Session,
3472) -> Result<()> {
3473 let channel_id = ChannelId::from_proto(request.channel_id);
3474 let message_id = MessageId::from_proto(request.message_id);
3475 let (connection_ids, existing_notification_ids) = session
3476 .db()
3477 .await
3478 .remove_channel_message(channel_id, message_id, session.user_id())
3479 .await?;
3480
3481 broadcast(
3482 Some(session.connection_id),
3483 connection_ids,
3484 move |connection| {
3485 session.peer.send(connection, request.clone())?;
3486
3487 for notification_id in &existing_notification_ids {
3488 session.peer.send(
3489 connection,
3490 proto::DeleteNotification {
3491 notification_id: (*notification_id).to_proto(),
3492 },
3493 )?;
3494 }
3495
3496 Ok(())
3497 },
3498 );
3499 response.send(proto::Ack {})?;
3500 Ok(())
3501}
3502
3503async fn update_channel_message(
3504 request: proto::UpdateChannelMessage,
3505 response: Response<proto::UpdateChannelMessage>,
3506 session: Session,
3507) -> Result<()> {
3508 let channel_id = ChannelId::from_proto(request.channel_id);
3509 let message_id = MessageId::from_proto(request.message_id);
3510 let updated_at = OffsetDateTime::now_utc();
3511 let UpdatedChannelMessage {
3512 message_id,
3513 participant_connection_ids,
3514 notifications,
3515 reply_to_message_id,
3516 timestamp,
3517 deleted_mention_notification_ids,
3518 updated_mention_notifications,
3519 } = session
3520 .db()
3521 .await
3522 .update_channel_message(
3523 channel_id,
3524 message_id,
3525 session.user_id(),
3526 request.body.as_str(),
3527 &request.mentions,
3528 updated_at,
3529 )
3530 .await?;
3531
3532 let nonce = request
3533 .nonce
3534 .clone()
3535 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3536
3537 let message = proto::ChannelMessage {
3538 sender_id: session.user_id().to_proto(),
3539 id: message_id.to_proto(),
3540 body: request.body.clone(),
3541 mentions: request.mentions.clone(),
3542 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3543 nonce: Some(nonce),
3544 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3545 edited_at: Some(updated_at.unix_timestamp() as u64),
3546 };
3547
3548 response.send(proto::Ack {})?;
3549
3550 let pool = &*session.connection_pool().await;
3551 broadcast(
3552 Some(session.connection_id),
3553 participant_connection_ids,
3554 |connection| {
3555 session.peer.send(
3556 connection,
3557 proto::ChannelMessageUpdate {
3558 channel_id: channel_id.to_proto(),
3559 message: Some(message.clone()),
3560 },
3561 )?;
3562
3563 for notification_id in &deleted_mention_notification_ids {
3564 session.peer.send(
3565 connection,
3566 proto::DeleteNotification {
3567 notification_id: (*notification_id).to_proto(),
3568 },
3569 )?;
3570 }
3571
3572 for notification in &updated_mention_notifications {
3573 session.peer.send(
3574 connection,
3575 proto::UpdateNotification {
3576 notification: Some(notification.clone()),
3577 },
3578 )?;
3579 }
3580
3581 Ok(())
3582 },
3583 );
3584
3585 send_notifications(pool, &session.peer, notifications);
3586
3587 Ok(())
3588}
3589
3590/// Mark a channel message as read
3591async fn acknowledge_channel_message(
3592 request: proto::AckChannelMessage,
3593 session: Session,
3594) -> Result<()> {
3595 let channel_id = ChannelId::from_proto(request.channel_id);
3596 let message_id = MessageId::from_proto(request.message_id);
3597 let notifications = session
3598 .db()
3599 .await
3600 .observe_channel_message(channel_id, session.user_id(), message_id)
3601 .await?;
3602 send_notifications(
3603 &*session.connection_pool().await,
3604 &session.peer,
3605 notifications,
3606 );
3607 Ok(())
3608}
3609
3610/// Mark a buffer version as synced
3611async fn acknowledge_buffer_version(
3612 request: proto::AckBufferOperation,
3613 session: Session,
3614) -> Result<()> {
3615 let buffer_id = BufferId::from_proto(request.buffer_id);
3616 session
3617 .db()
3618 .await
3619 .observe_buffer_version(
3620 buffer_id,
3621 session.user_id(),
3622 request.epoch as i32,
3623 &request.version,
3624 )
3625 .await?;
3626 Ok(())
3627}
3628
3629async fn count_language_model_tokens(
3630 request: proto::CountLanguageModelTokens,
3631 response: Response<proto::CountLanguageModelTokens>,
3632 session: Session,
3633 config: &Config,
3634) -> Result<()> {
3635 authorize_access_to_legacy_llm_endpoints(&session).await?;
3636
3637 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3638 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3639 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3640 };
3641
3642 session
3643 .app_state
3644 .rate_limiter
3645 .check(&*rate_limit, session.user_id())
3646 .await?;
3647
3648 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3649 Some(proto::LanguageModelProvider::Google) => {
3650 let api_key = config
3651 .google_ai_api_key
3652 .as_ref()
3653 .context("no Google AI API key configured on the server")?;
3654 google_ai::count_tokens(
3655 session.http_client.as_ref(),
3656 google_ai::API_URL,
3657 api_key,
3658 serde_json::from_str(&request.request)?,
3659 )
3660 .await?
3661 }
3662 _ => return Err(anyhow!("unsupported provider"))?,
3663 };
3664
3665 response.send(proto::CountLanguageModelTokensResponse {
3666 token_count: result.total_tokens as u32,
3667 })?;
3668
3669 Ok(())
3670}
3671
3672struct ZedProCountLanguageModelTokensRateLimit;
3673
3674impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3675 fn capacity(&self) -> usize {
3676 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3677 .ok()
3678 .and_then(|v| v.parse().ok())
3679 .unwrap_or(600) // Picked arbitrarily
3680 }
3681
3682 fn refill_duration(&self) -> chrono::Duration {
3683 chrono::Duration::hours(1)
3684 }
3685
3686 fn db_name(&self) -> &'static str {
3687 "zed-pro:count-language-model-tokens"
3688 }
3689}
3690
3691struct FreeCountLanguageModelTokensRateLimit;
3692
3693impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3694 fn capacity(&self) -> usize {
3695 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3696 .ok()
3697 .and_then(|v| v.parse().ok())
3698 .unwrap_or(600 / 10) // Picked arbitrarily
3699 }
3700
3701 fn refill_duration(&self) -> chrono::Duration {
3702 chrono::Duration::hours(1)
3703 }
3704
3705 fn db_name(&self) -> &'static str {
3706 "free:count-language-model-tokens"
3707 }
3708}
3709
3710struct ZedProComputeEmbeddingsRateLimit;
3711
3712impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3713 fn capacity(&self) -> usize {
3714 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3715 .ok()
3716 .and_then(|v| v.parse().ok())
3717 .unwrap_or(5000) // Picked arbitrarily
3718 }
3719
3720 fn refill_duration(&self) -> chrono::Duration {
3721 chrono::Duration::hours(1)
3722 }
3723
3724 fn db_name(&self) -> &'static str {
3725 "zed-pro:compute-embeddings"
3726 }
3727}
3728
3729struct FreeComputeEmbeddingsRateLimit;
3730
3731impl RateLimit for FreeComputeEmbeddingsRateLimit {
3732 fn capacity(&self) -> usize {
3733 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3734 .ok()
3735 .and_then(|v| v.parse().ok())
3736 .unwrap_or(5000 / 10) // Picked arbitrarily
3737 }
3738
3739 fn refill_duration(&self) -> chrono::Duration {
3740 chrono::Duration::hours(1)
3741 }
3742
3743 fn db_name(&self) -> &'static str {
3744 "free:compute-embeddings"
3745 }
3746}
3747
3748async fn compute_embeddings(
3749 request: proto::ComputeEmbeddings,
3750 response: Response<proto::ComputeEmbeddings>,
3751 session: Session,
3752 api_key: Option<Arc<str>>,
3753) -> Result<()> {
3754 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3755 authorize_access_to_legacy_llm_endpoints(&session).await?;
3756
3757 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3758 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3759 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3760 };
3761
3762 session
3763 .app_state
3764 .rate_limiter
3765 .check(&*rate_limit, session.user_id())
3766 .await?;
3767
3768 let embeddings = match request.model.as_str() {
3769 "openai/text-embedding-3-small" => {
3770 open_ai::embed(
3771 session.http_client.as_ref(),
3772 OPEN_AI_API_URL,
3773 &api_key,
3774 OpenAiEmbeddingModel::TextEmbedding3Small,
3775 request.texts.iter().map(|text| text.as_str()),
3776 )
3777 .await?
3778 }
3779 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3780 };
3781
3782 let embeddings = request
3783 .texts
3784 .iter()
3785 .map(|text| {
3786 let mut hasher = sha2::Sha256::new();
3787 hasher.update(text.as_bytes());
3788 let result = hasher.finalize();
3789 result.to_vec()
3790 })
3791 .zip(
3792 embeddings
3793 .data
3794 .into_iter()
3795 .map(|embedding| embedding.embedding),
3796 )
3797 .collect::<HashMap<_, _>>();
3798
3799 let db = session.db().await;
3800 db.save_embeddings(&request.model, &embeddings)
3801 .await
3802 .context("failed to save embeddings")
3803 .trace_err();
3804
3805 response.send(proto::ComputeEmbeddingsResponse {
3806 embeddings: embeddings
3807 .into_iter()
3808 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3809 .collect(),
3810 })?;
3811 Ok(())
3812}
3813
3814async fn get_cached_embeddings(
3815 request: proto::GetCachedEmbeddings,
3816 response: Response<proto::GetCachedEmbeddings>,
3817 session: Session,
3818) -> Result<()> {
3819 authorize_access_to_legacy_llm_endpoints(&session).await?;
3820
3821 let db = session.db().await;
3822 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3823
3824 response.send(proto::GetCachedEmbeddingsResponse {
3825 embeddings: embeddings
3826 .into_iter()
3827 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3828 .collect(),
3829 })?;
3830 Ok(())
3831}
3832
3833/// This is leftover from before the LLM service.
3834///
3835/// The endpoints protected by this check will be moved there eventually.
3836async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3837 if session.is_staff() {
3838 Ok(())
3839 } else {
3840 Err(anyhow!("permission denied"))?
3841 }
3842}
3843
3844/// Get a Supermaven API key for the user
3845async fn get_supermaven_api_key(
3846 _request: proto::GetSupermavenApiKey,
3847 response: Response<proto::GetSupermavenApiKey>,
3848 session: Session,
3849) -> Result<()> {
3850 let user_id: String = session.user_id().to_string();
3851 if !session.is_staff() {
3852 return Err(anyhow!("supermaven not enabled for this account"))?;
3853 }
3854
3855 let email = session
3856 .email()
3857 .ok_or_else(|| anyhow!("user must have an email"))?;
3858
3859 let supermaven_admin_api = session
3860 .supermaven_client
3861 .as_ref()
3862 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3863
3864 let result = supermaven_admin_api
3865 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3866 .await?;
3867
3868 response.send(proto::GetSupermavenApiKeyResponse {
3869 api_key: result.api_key,
3870 })?;
3871
3872 Ok(())
3873}
3874
3875/// Start receiving chat updates for a channel
3876async fn join_channel_chat(
3877 request: proto::JoinChannelChat,
3878 response: Response<proto::JoinChannelChat>,
3879 session: Session,
3880) -> Result<()> {
3881 let channel_id = ChannelId::from_proto(request.channel_id);
3882
3883 let db = session.db().await;
3884 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3885 .await?;
3886 let messages = db
3887 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3888 .await?;
3889 response.send(proto::JoinChannelChatResponse {
3890 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3891 messages,
3892 })?;
3893 Ok(())
3894}
3895
3896/// Stop receiving chat updates for a channel
3897async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3898 let channel_id = ChannelId::from_proto(request.channel_id);
3899 session
3900 .db()
3901 .await
3902 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3903 .await?;
3904 Ok(())
3905}
3906
3907/// Retrieve the chat history for a channel
3908async fn get_channel_messages(
3909 request: proto::GetChannelMessages,
3910 response: Response<proto::GetChannelMessages>,
3911 session: Session,
3912) -> Result<()> {
3913 let channel_id = ChannelId::from_proto(request.channel_id);
3914 let messages = session
3915 .db()
3916 .await
3917 .get_channel_messages(
3918 channel_id,
3919 session.user_id(),
3920 MESSAGE_COUNT_PER_PAGE,
3921 Some(MessageId::from_proto(request.before_message_id)),
3922 )
3923 .await?;
3924 response.send(proto::GetChannelMessagesResponse {
3925 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3926 messages,
3927 })?;
3928 Ok(())
3929}
3930
3931/// Retrieve specific chat messages
3932async fn get_channel_messages_by_id(
3933 request: proto::GetChannelMessagesById,
3934 response: Response<proto::GetChannelMessagesById>,
3935 session: Session,
3936) -> Result<()> {
3937 let message_ids = request
3938 .message_ids
3939 .iter()
3940 .map(|id| MessageId::from_proto(*id))
3941 .collect::<Vec<_>>();
3942 let messages = session
3943 .db()
3944 .await
3945 .get_channel_messages_by_id(session.user_id(), &message_ids)
3946 .await?;
3947 response.send(proto::GetChannelMessagesResponse {
3948 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3949 messages,
3950 })?;
3951 Ok(())
3952}
3953
3954/// Retrieve the current users notifications
3955async fn get_notifications(
3956 request: proto::GetNotifications,
3957 response: Response<proto::GetNotifications>,
3958 session: Session,
3959) -> Result<()> {
3960 let notifications = session
3961 .db()
3962 .await
3963 .get_notifications(
3964 session.user_id(),
3965 NOTIFICATION_COUNT_PER_PAGE,
3966 request.before_id.map(db::NotificationId::from_proto),
3967 )
3968 .await?;
3969 response.send(proto::GetNotificationsResponse {
3970 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3971 notifications,
3972 })?;
3973 Ok(())
3974}
3975
3976/// Mark notifications as read
3977async fn mark_notification_as_read(
3978 request: proto::MarkNotificationRead,
3979 response: Response<proto::MarkNotificationRead>,
3980 session: Session,
3981) -> Result<()> {
3982 let database = &session.db().await;
3983 let notifications = database
3984 .mark_notification_as_read_by_id(
3985 session.user_id(),
3986 NotificationId::from_proto(request.notification_id),
3987 )
3988 .await?;
3989 send_notifications(
3990 &*session.connection_pool().await,
3991 &session.peer,
3992 notifications,
3993 );
3994 response.send(proto::Ack {})?;
3995 Ok(())
3996}
3997
3998/// Get the current users information
3999async fn get_private_user_info(
4000 _request: proto::GetPrivateUserInfo,
4001 response: Response<proto::GetPrivateUserInfo>,
4002 session: Session,
4003) -> Result<()> {
4004 let db = session.db().await;
4005
4006 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4007 let user = db
4008 .get_user_by_id(session.user_id())
4009 .await?
4010 .ok_or_else(|| anyhow!("user not found"))?;
4011 let flags = db.get_user_flags(session.user_id()).await?;
4012
4013 response.send(proto::GetPrivateUserInfoResponse {
4014 metrics_id,
4015 staff: user.admin,
4016 flags,
4017 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4018 })?;
4019 Ok(())
4020}
4021
4022/// Accept the terms of service (tos) on behalf of the current user
4023async fn accept_terms_of_service(
4024 _request: proto::AcceptTermsOfService,
4025 response: Response<proto::AcceptTermsOfService>,
4026 session: Session,
4027) -> Result<()> {
4028 let db = session.db().await;
4029
4030 let accepted_tos_at = Utc::now();
4031 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4032 .await?;
4033
4034 response.send(proto::AcceptTermsOfServiceResponse {
4035 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4036 })?;
4037 Ok(())
4038}
4039
/// The minimum age an account must have reached in order to use the LLM
/// service (30 days).
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4042
4043async fn get_llm_api_token(
4044 _request: proto::GetLlmToken,
4045 response: Response<proto::GetLlmToken>,
4046 session: Session,
4047) -> Result<()> {
4048 let db = session.db().await;
4049
4050 let flags = db.get_user_flags(session.user_id()).await?;
4051 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4052
4053 if !session.is_staff() && !has_language_models_feature_flag {
4054 Err(anyhow!("permission denied"))?
4055 }
4056
4057 let user_id = session.user_id();
4058 let user = db
4059 .get_user_by_id(user_id)
4060 .await?
4061 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4062
4063 if user.accepted_tos_at.is_none() {
4064 Err(anyhow!("terms of service not accepted"))?
4065 }
4066
4067 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4068 let billing_preferences = db.get_billing_preferences(user.id).await?;
4069
4070 let token = LlmTokenClaims::create(
4071 &user,
4072 session.is_staff(),
4073 billing_preferences,
4074 &flags,
4075 has_llm_subscription,
4076 session.current_plan(&db).await?,
4077 session.system_id.clone(),
4078 &session.app_state.config,
4079 )?;
4080 response.send(proto::GetLlmTokenResponse { token })?;
4081 Ok(())
4082}
4083
4084fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4085 let message = match message {
4086 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4087 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4088 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4089 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4090 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4091 code: frame.code.into(),
4092 reason: frame.reason,
4093 })),
4094 // We should never receive a frame while reading the message, according
4095 // to the `tungstenite` maintainers:
4096 //
4097 // > It cannot occur when you read messages from the WebSocket, but it
4098 // > can be used when you want to send the raw frames (e.g. you want to
4099 // > send the frames to the WebSocket without composing the full message first).
4100 // >
4101 // > — https://github.com/snapview/tungstenite-rs/issues/268
4102 TungsteniteMessage::Frame(_) => {
4103 bail!("received an unexpected frame while reading the message")
4104 }
4105 };
4106
4107 Ok(message)
4108}
4109
4110fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4111 match message {
4112 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4113 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4114 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4115 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4116 AxumMessage::Close(frame) => {
4117 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4118 code: frame.code.into(),
4119 reason: frame.reason,
4120 }))
4121 }
4122 }
4123}
4124
4125fn notify_membership_updated(
4126 connection_pool: &mut ConnectionPool,
4127 result: MembershipUpdated,
4128 user_id: UserId,
4129 peer: &Peer,
4130) {
4131 for membership in &result.new_channels.channel_memberships {
4132 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4133 }
4134 for channel_id in &result.removed_channels {
4135 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4136 }
4137
4138 let user_channels_update = proto::UpdateUserChannels {
4139 channel_memberships: result
4140 .new_channels
4141 .channel_memberships
4142 .iter()
4143 .map(|cm| proto::ChannelMembership {
4144 channel_id: cm.channel_id.to_proto(),
4145 role: cm.role.into(),
4146 })
4147 .collect(),
4148 ..Default::default()
4149 };
4150
4151 let mut update = build_channels_update(result.new_channels);
4152 update.delete_channels = result
4153 .removed_channels
4154 .into_iter()
4155 .map(|id| id.to_proto())
4156 .collect();
4157 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4158
4159 for connection_id in connection_pool.user_connection_ids(user_id) {
4160 peer.send(connection_id, user_channels_update.clone())
4161 .trace_err();
4162 peer.send(connection_id, update.clone()).trace_err();
4163 }
4164}
4165
4166fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4167 proto::UpdateUserChannels {
4168 channel_memberships: channels
4169 .channel_memberships
4170 .iter()
4171 .map(|m| proto::ChannelMembership {
4172 channel_id: m.channel_id.to_proto(),
4173 role: m.role.into(),
4174 })
4175 .collect(),
4176 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4177 observed_channel_message_id: channels.observed_channel_messages.clone(),
4178 }
4179}
4180
4181fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4182 let mut update = proto::UpdateChannels::default();
4183
4184 for channel in channels.channels {
4185 update.channels.push(channel.to_proto());
4186 }
4187
4188 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4189 update.latest_channel_message_ids = channels.latest_channel_messages;
4190
4191 for (channel_id, participants) in channels.channel_participants {
4192 update
4193 .channel_participants
4194 .push(proto::ChannelParticipants {
4195 channel_id: channel_id.to_proto(),
4196 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4197 });
4198 }
4199
4200 for channel in channels.invited_channels {
4201 update.channel_invitations.push(channel.to_proto());
4202 }
4203
4204 update
4205}
4206
4207fn build_initial_contacts_update(
4208 contacts: Vec<db::Contact>,
4209 pool: &ConnectionPool,
4210) -> proto::UpdateContacts {
4211 let mut update = proto::UpdateContacts::default();
4212
4213 for contact in contacts {
4214 match contact {
4215 db::Contact::Accepted { user_id, busy } => {
4216 update.contacts.push(contact_for_user(user_id, busy, pool));
4217 }
4218 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4219 db::Contact::Incoming { user_id } => {
4220 update
4221 .incoming_requests
4222 .push(proto::IncomingContactRequest {
4223 requester_id: user_id.to_proto(),
4224 })
4225 }
4226 }
4227 }
4228
4229 update
4230}
4231
4232fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4233 proto::Contact {
4234 user_id: user_id.to_proto(),
4235 online: pool.is_user_online(user_id),
4236 busy,
4237 }
4238}
4239
4240fn room_updated(room: &proto::Room, peer: &Peer) {
4241 broadcast(
4242 None,
4243 room.participants
4244 .iter()
4245 .filter_map(|participant| Some(participant.peer_id?.into())),
4246 |peer_id| {
4247 peer.send(
4248 peer_id,
4249 proto::RoomUpdated {
4250 room: Some(room.clone()),
4251 },
4252 )
4253 },
4254 );
4255}
4256
4257fn channel_updated(
4258 channel: &db::channel::Model,
4259 room: &proto::Room,
4260 peer: &Peer,
4261 pool: &ConnectionPool,
4262) {
4263 let participants = room
4264 .participants
4265 .iter()
4266 .map(|p| p.user_id)
4267 .collect::<Vec<_>>();
4268
4269 broadcast(
4270 None,
4271 pool.channel_connection_ids(channel.root_id())
4272 .filter_map(|(channel_id, role)| {
4273 role.can_see_channel(channel.visibility)
4274 .then_some(channel_id)
4275 }),
4276 |peer_id| {
4277 peer.send(
4278 peer_id,
4279 proto::UpdateChannels {
4280 channel_participants: vec![proto::ChannelParticipants {
4281 channel_id: channel.id.to_proto(),
4282 participant_user_ids: participants.clone(),
4283 }],
4284 ..Default::default()
4285 },
4286 )
4287 },
4288 );
4289}
4290
4291async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4292 let db = session.db().await;
4293
4294 let contacts = db.get_contacts(user_id).await?;
4295 let busy = db.is_user_busy(user_id).await?;
4296
4297 let pool = session.connection_pool().await;
4298 let updated_contact = contact_for_user(user_id, busy, &pool);
4299 for contact in contacts {
4300 if let db::Contact::Accepted {
4301 user_id: contact_user_id,
4302 ..
4303 } = contact
4304 {
4305 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4306 session
4307 .peer
4308 .send(
4309 contact_conn_id,
4310 proto::UpdateContacts {
4311 contacts: vec![updated_contact.clone()],
4312 remove_contacts: Default::default(),
4313 incoming_requests: Default::default(),
4314 remove_incoming_requests: Default::default(),
4315 outgoing_requests: Default::default(),
4316 remove_outgoing_requests: Default::default(),
4317 },
4318 )
4319 .trace_err();
4320 }
4321 }
4322 }
4323 Ok(())
4324}
4325
4326async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4327 let mut contacts_to_update = HashSet::default();
4328
4329 let room_id;
4330 let canceled_calls_to_user_ids;
4331 let livekit_room;
4332 let delete_livekit_room;
4333 let room;
4334 let channel;
4335
4336 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4337 contacts_to_update.insert(session.user_id());
4338
4339 for project in left_room.left_projects.values() {
4340 project_left(project, session);
4341 }
4342
4343 room_id = RoomId::from_proto(left_room.room.id);
4344 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4345 livekit_room = mem::take(&mut left_room.room.livekit_room);
4346 delete_livekit_room = left_room.deleted;
4347 room = mem::take(&mut left_room.room);
4348 channel = mem::take(&mut left_room.channel);
4349
4350 room_updated(&room, &session.peer);
4351 } else {
4352 return Ok(());
4353 }
4354
4355 if let Some(channel) = channel {
4356 channel_updated(
4357 &channel,
4358 &room,
4359 &session.peer,
4360 &*session.connection_pool().await,
4361 );
4362 }
4363
4364 {
4365 let pool = session.connection_pool().await;
4366 for canceled_user_id in canceled_calls_to_user_ids {
4367 for connection_id in pool.user_connection_ids(canceled_user_id) {
4368 session
4369 .peer
4370 .send(
4371 connection_id,
4372 proto::CallCanceled {
4373 room_id: room_id.to_proto(),
4374 },
4375 )
4376 .trace_err();
4377 }
4378 contacts_to_update.insert(canceled_user_id);
4379 }
4380 }
4381
4382 for contact_user_id in contacts_to_update {
4383 update_user_contacts(contact_user_id, session).await?;
4384 }
4385
4386 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4387 live_kit
4388 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4389 .await
4390 .trace_err();
4391
4392 if delete_livekit_room {
4393 live_kit.delete_room(livekit_room).await.trace_err();
4394 }
4395 }
4396
4397 Ok(())
4398}
4399
4400async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4401 let left_channel_buffers = session
4402 .db()
4403 .await
4404 .leave_channel_buffers(session.connection_id)
4405 .await?;
4406
4407 for left_buffer in left_channel_buffers {
4408 channel_buffer_updated(
4409 session.connection_id,
4410 left_buffer.connections,
4411 &proto::UpdateChannelBufferCollaborators {
4412 channel_id: left_buffer.channel_id.to_proto(),
4413 collaborators: left_buffer.collaborators,
4414 },
4415 &session.peer,
4416 );
4417 }
4418
4419 Ok(())
4420}
4421
4422fn project_left(project: &db::LeftProject, session: &Session) {
4423 for connection_id in &project.connection_ids {
4424 if project.should_unshare {
4425 session
4426 .peer
4427 .send(
4428 *connection_id,
4429 proto::UnshareProject {
4430 project_id: project.id.to_proto(),
4431 },
4432 )
4433 .trace_err();
4434 } else {
4435 session
4436 .peer
4437 .send(
4438 *connection_id,
4439 proto::RemoveProjectCollaborator {
4440 project_id: project.id.to_proto(),
4441 peer_id: Some(session.connection_id.into()),
4442 },
4443 )
4444 .trace_err();
4445 }
4446 }
4447}
4448
4449pub trait ResultExt {
4450 type Ok;
4451
4452 fn trace_err(self) -> Option<Self::Ok>;
4453}
4454
4455impl<T, E> ResultExt for Result<T, E>
4456where
4457 E: std::fmt::Debug,
4458{
4459 type Ok = T;
4460
4461 #[track_caller]
4462 fn trace_err(self) -> Option<T> {
4463 match self {
4464 Ok(value) => Some(value),
4465 Err(error) => {
4466 tracing::error!("{:?}", error);
4467 None
4468 }
4469 }
4470 }
4471}