1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// Reconnection grace period for clients.
///
/// NOTE(review): not referenced in this section of the file; presumably bounds how
/// long a dropped connection may take to come back — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay after startup before resources left behind by previous server instances
/// are cleaned up (awaited by `Server::start`).
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone,
/// their stale resources can be cleaned up safely.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

// Paging and size limits; their consumers live outside this section of the file.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1 << 10;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
88type MessageHandler =
89 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
91struct Response<R> {
92 peer: Arc<Peer>,
93 receipt: Receipt<R>,
94 responded: Arc<AtomicBool>,
95}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum Principal {
107 User(User),
108 Impersonated { user: User, admin: User },
109}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
127#[derive(Clone)]
128struct Session {
129 principal: Principal,
130 connection_id: ConnectionId,
131 db: Arc<tokio::sync::Mutex<DbHandle>>,
132 peer: Arc<Peer>,
133 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
134 app_state: Arc<AppState>,
135 supermaven_client: Option<Arc<SupermavenAdminApi>>,
136 http_client: Arc<dyn HttpClient>,
137 /// The GeoIP country code for the user.
138 #[allow(unused)]
139 geoip_country_code: Option<String>,
140 system_id: Option<String>,
141 _executor: Executor,
142}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
226struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The collaboration server: owns the RPC [`Peer`], the pool of live connections
/// and the table of per-message-type handlers.
pub struct Server {
    /// Identifier of this server instance; only rewritten by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` on teardown; each connection task subscribes to it.
    teardown: watch::Sender<bool>,
}
244
245pub(crate) struct ConnectionPoolGuard<'a> {
246 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
247 _not_send: PhantomData<Rc<()>>,
248}
249
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The pool's lock is held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
267 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
268 let mut server = Self {
269 id: parking_lot::Mutex::new(id),
270 peer: Peer::new(id.0 as u32),
271 app_state: app_state.clone(),
272 connection_pool: Default::default(),
273 handlers: Default::default(),
274 teardown: watch::channel(false).0,
275 };
276
277 server
278 .add_request_handler(ping)
279 .add_request_handler(create_room)
280 .add_request_handler(join_room)
281 .add_request_handler(rejoin_room)
282 .add_request_handler(leave_room)
283 .add_request_handler(set_room_participant_role)
284 .add_request_handler(call)
285 .add_request_handler(cancel_call)
286 .add_message_handler(decline_call)
287 .add_request_handler(update_participant_location)
288 .add_request_handler(share_project)
289 .add_message_handler(unshare_project)
290 .add_request_handler(join_project)
291 .add_message_handler(leave_project)
292 .add_request_handler(update_project)
293 .add_request_handler(update_worktree)
294 .add_message_handler(start_language_server)
295 .add_message_handler(update_language_server)
296 .add_message_handler(update_diagnostic_summary)
297 .add_message_handler(update_worktree_settings)
298 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
299 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
300 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
301 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
302 .add_request_handler(forward_find_search_candidates_request)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
307 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
308 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
311 .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
313 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
314 .add_request_handler(
315 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
316 )
317 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
318 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
319 .add_request_handler(
320 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
321 )
322 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
323 .add_request_handler(
324 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
325 )
326 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
327 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
328 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
329 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
331 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
333 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
334 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
341 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
342 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
343 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
344 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
345 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
346 .add_message_handler(create_buffer_for_peer)
347 .add_request_handler(update_buffer)
348 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
349 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
350 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
351 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
353 .add_request_handler(get_users)
354 .add_request_handler(fuzzy_search_users)
355 .add_request_handler(request_contact)
356 .add_request_handler(remove_contact)
357 .add_request_handler(respond_to_contact_request)
358 .add_message_handler(subscribe_to_channels)
359 .add_request_handler(create_channel)
360 .add_request_handler(delete_channel)
361 .add_request_handler(invite_channel_member)
362 .add_request_handler(remove_channel_member)
363 .add_request_handler(set_channel_member_role)
364 .add_request_handler(set_channel_visibility)
365 .add_request_handler(rename_channel)
366 .add_request_handler(join_channel_buffer)
367 .add_request_handler(leave_channel_buffer)
368 .add_message_handler(update_channel_buffer)
369 .add_request_handler(rejoin_channel_buffers)
370 .add_request_handler(get_channel_members)
371 .add_request_handler(respond_to_channel_invite)
372 .add_request_handler(join_channel)
373 .add_request_handler(join_channel_chat)
374 .add_message_handler(leave_channel_chat)
375 .add_request_handler(send_channel_message)
376 .add_request_handler(remove_channel_message)
377 .add_request_handler(update_channel_message)
378 .add_request_handler(get_channel_messages)
379 .add_request_handler(get_channel_messages_by_id)
380 .add_request_handler(get_notifications)
381 .add_request_handler(mark_notification_as_read)
382 .add_request_handler(move_channel)
383 .add_request_handler(follow)
384 .add_message_handler(unfollow)
385 .add_message_handler(update_followers)
386 .add_request_handler(get_private_user_info)
387 .add_request_handler(get_llm_api_token)
388 .add_request_handler(accept_terms_of_service)
389 .add_message_handler(acknowledge_channel_message)
390 .add_message_handler(acknowledge_buffer_version)
391 .add_request_handler(get_supermaven_api_key)
392 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
393 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
394 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
395 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
396 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
397 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
398 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
399 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
400 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
401 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
402 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
403 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
404 .add_message_handler(update_context)
405 .add_request_handler({
406 let app_state = app_state.clone();
407 move |request, response, session| {
408 let app_state = app_state.clone();
409 async move {
410 count_language_model_tokens(request, response, session, &app_state.config)
411 .await
412 }
413 }
414 })
415 .add_request_handler(get_cached_embeddings)
416 .add_request_handler({
417 let app_state = app_state.clone();
418 move |request, response, session| {
419 compute_embeddings(
420 request,
421 response,
422 session,
423 app_state.config.openai_api_key.clone(),
424 )
425 }
426 });
427
428 Arc::new(server)
429 }
430
431 pub async fn start(&self) -> Result<()> {
432 let server_id = *self.id.lock();
433 let app_state = self.app_state.clone();
434 let peer = self.peer.clone();
435 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
436 let pool = self.connection_pool.clone();
437 let livekit_client = self.app_state.livekit_client.clone();
438
439 let span = info_span!("start server");
440 self.app_state.executor.spawn_detached(
441 async move {
442 tracing::info!("waiting for cleanup timeout");
443 timeout.await;
444 tracing::info!("cleanup timeout expired, retrieving stale rooms");
445 if let Some((room_ids, channel_ids)) = app_state
446 .db
447 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
448 .await
449 .trace_err()
450 {
451 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
452 tracing::info!(
453 stale_channel_buffer_count = channel_ids.len(),
454 "retrieved stale channel buffers"
455 );
456
457 for channel_id in channel_ids {
458 if let Some(refreshed_channel_buffer) = app_state
459 .db
460 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
461 .await
462 .trace_err()
463 {
464 for connection_id in refreshed_channel_buffer.connection_ids {
465 peer.send(
466 connection_id,
467 proto::UpdateChannelBufferCollaborators {
468 channel_id: channel_id.to_proto(),
469 collaborators: refreshed_channel_buffer
470 .collaborators
471 .clone(),
472 },
473 )
474 .trace_err();
475 }
476 }
477 }
478
479 for room_id in room_ids {
480 let mut contacts_to_update = HashSet::default();
481 let mut canceled_calls_to_user_ids = Vec::new();
482 let mut livekit_room = String::new();
483 let mut delete_livekit_room = false;
484
485 if let Some(mut refreshed_room) = app_state
486 .db
487 .clear_stale_room_participants(room_id, server_id)
488 .await
489 .trace_err()
490 {
491 tracing::info!(
492 room_id = room_id.0,
493 new_participant_count = refreshed_room.room.participants.len(),
494 "refreshed room"
495 );
496 room_updated(&refreshed_room.room, &peer);
497 if let Some(channel) = refreshed_room.channel.as_ref() {
498 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
499 }
500 contacts_to_update
501 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
502 contacts_to_update
503 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
504 canceled_calls_to_user_ids =
505 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
506 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
507 delete_livekit_room = refreshed_room.room.participants.is_empty();
508 }
509
510 {
511 let pool = pool.lock();
512 for canceled_user_id in canceled_calls_to_user_ids {
513 for connection_id in pool.user_connection_ids(canceled_user_id) {
514 peer.send(
515 connection_id,
516 proto::CallCanceled {
517 room_id: room_id.to_proto(),
518 },
519 )
520 .trace_err();
521 }
522 }
523 }
524
525 for user_id in contacts_to_update {
526 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
527 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
528 if let Some((busy, contacts)) = busy.zip(contacts) {
529 let pool = pool.lock();
530 let updated_contact = contact_for_user(user_id, busy, &pool);
531 for contact in contacts {
532 if let db::Contact::Accepted {
533 user_id: contact_user_id,
534 ..
535 } = contact
536 {
537 for contact_conn_id in
538 pool.user_connection_ids(contact_user_id)
539 {
540 peer.send(
541 contact_conn_id,
542 proto::UpdateContacts {
543 contacts: vec![updated_contact.clone()],
544 remove_contacts: Default::default(),
545 incoming_requests: Default::default(),
546 remove_incoming_requests: Default::default(),
547 outgoing_requests: Default::default(),
548 remove_outgoing_requests: Default::default(),
549 },
550 )
551 .trace_err();
552 }
553 }
554 }
555 }
556 }
557
558 if let Some(live_kit) = livekit_client.as_ref() {
559 if delete_livekit_room {
560 live_kit.delete_room(livekit_room).await.trace_err();
561 }
562 }
563 }
564 }
565
566 app_state
567 .db
568 .delete_stale_servers(&app_state.config.zed_environment, server_id)
569 .await
570 .trace_err();
571 }
572 .instrument(span),
573 );
574 Ok(())
575 }
576
577 pub fn teardown(&self) {
578 self.peer.teardown();
579 self.connection_pool.lock().reset();
580 let _ = self.teardown.send(true);
581 }
582
583 #[cfg(test)]
584 pub fn reset(&self, id: ServerId) {
585 self.teardown();
586 *self.id.lock() = id;
587 self.peer.reset(id.0 as u32);
588 let _ = self.teardown.send(false);
589 }
590
591 #[cfg(test)]
592 pub fn id(&self) -> ServerId {
593 *self.id.lock()
594 }
595
596 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
597 where
598 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
599 Fut: 'static + Send + Future<Output = Result<()>>,
600 M: EnvelopedMessage,
601 {
602 let prev_handler = self.handlers.insert(
603 TypeId::of::<M>(),
604 Box::new(move |envelope, session| {
605 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
606 let received_at = envelope.received_at;
607 tracing::info!("message received");
608 let start_time = Instant::now();
609 let future = (handler)(*envelope, session);
610 async move {
611 let result = future.await;
612 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
613 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
614 let queue_duration_ms = total_duration_ms - processing_duration_ms;
615 let payload_type = M::NAME;
616
617 match result {
618 Err(error) => {
619 tracing::error!(
620 ?error,
621 total_duration_ms,
622 processing_duration_ms,
623 queue_duration_ms,
624 payload_type,
625 "error handling message"
626 )
627 }
628 Ok(()) => tracing::info!(
629 total_duration_ms,
630 processing_duration_ms,
631 queue_duration_ms,
632 "finished handling message"
633 ),
634 }
635 }
636 .boxed()
637 }),
638 );
639 if prev_handler.is_some() {
640 panic!("registered a handler for the same message twice");
641 }
642 self
643 }
644
645 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
646 where
647 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
648 Fut: 'static + Send + Future<Output = Result<()>>,
649 M: EnvelopedMessage,
650 {
651 self.add_handler(move |envelope, session| handler(envelope.payload, session));
652 self
653 }
654
655 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
656 where
657 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
658 Fut: Send + Future<Output = Result<()>>,
659 M: RequestMessage,
660 {
661 let handler = Arc::new(handler);
662 self.add_handler(move |envelope, session| {
663 let receipt = envelope.receipt();
664 let handler = handler.clone();
665 async move {
666 let peer = session.peer.clone();
667 let responded = Arc::new(AtomicBool::default());
668 let response = Response {
669 peer: peer.clone(),
670 responded: responded.clone(),
671 receipt,
672 };
673 match (handler)(envelope.payload, response, session).await {
674 Ok(()) => {
675 if responded.load(std::sync::atomic::Ordering::SeqCst) {
676 Ok(())
677 } else {
678 Err(anyhow!("handler did not send a response"))?
679 }
680 }
681 Err(error) => {
682 let proto_err = match &error {
683 Error::Internal(err) => err.to_proto(),
684 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
685 };
686 peer.respond_with_error(receipt, proto_err)?;
687 Err(error)
688 }
689 }
690 }
691 })
692 }
693
694 #[allow(clippy::too_many_arguments)]
695 pub fn handle_connection(
696 self: &Arc<Self>,
697 connection: Connection,
698 address: String,
699 principal: Principal,
700 zed_version: ZedVersion,
701 geoip_country_code: Option<String>,
702 system_id: Option<String>,
703 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
704 executor: Executor,
705 ) -> impl Future<Output = ()> {
706 let this = self.clone();
707 let span = info_span!("handle connection", %address,
708 connection_id=field::Empty,
709 user_id=field::Empty,
710 login=field::Empty,
711 impersonator=field::Empty,
712 geoip_country_code=field::Empty
713 );
714 principal.update_span(&span);
715 if let Some(country_code) = geoip_country_code.as_ref() {
716 span.record("geoip_country_code", country_code);
717 }
718
719 let mut teardown = self.teardown.subscribe();
720 async move {
721 if *teardown.borrow() {
722 tracing::error!("server is tearing down");
723 return
724 }
725 let (connection_id, handle_io, mut incoming_rx) = this
726 .peer
727 .add_connection(connection, {
728 let executor = executor.clone();
729 move |duration| executor.sleep(duration)
730 });
731 tracing::Span::current().record("connection_id", format!("{}", connection_id));
732
733 tracing::info!("connection opened");
734
735 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
736 let http_client = match ReqwestClient::user_agent(&user_agent) {
737 Ok(http_client) => Arc::new(http_client),
738 Err(error) => {
739 tracing::error!(?error, "failed to create HTTP client");
740 return;
741 }
742 };
743
744 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
745 supermaven_admin_api_key.to_string(),
746 http_client.clone(),
747 )));
748
749 let session = Session {
750 principal: principal.clone(),
751 connection_id,
752 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
753 peer: this.peer.clone(),
754 connection_pool: this.connection_pool.clone(),
755 app_state: this.app_state.clone(),
756 http_client,
757 geoip_country_code,
758 system_id,
759 _executor: executor.clone(),
760 supermaven_client,
761 };
762
763 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
764 tracing::error!(?error, "failed to send initial client update");
765 return;
766 }
767
768 let handle_io = handle_io.fuse();
769 futures::pin_mut!(handle_io);
770
771 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
772 // This prevents deadlocks when e.g., client A performs a request to client B and
773 // client B performs a request to client A. If both clients stop processing further
774 // messages until their respective request completes, they won't have a chance to
775 // respond to the other client's request and cause a deadlock.
776 //
777 // This arrangement ensures we will attempt to process earlier messages first, but fall
778 // back to processing messages arrived later in the spirit of making progress.
779 let mut foreground_message_handlers = FuturesUnordered::new();
780 let concurrent_handlers = Arc::new(Semaphore::new(256));
781 loop {
782 let next_message = async {
783 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
784 let message = incoming_rx.next().await;
785 (permit, message)
786 }.fuse();
787 futures::pin_mut!(next_message);
788 futures::select_biased! {
789 _ = teardown.changed().fuse() => return,
790 result = handle_io => {
791 if let Err(error) = result {
792 tracing::error!(?error, "error handling I/O");
793 }
794 break;
795 }
796 _ = foreground_message_handlers.next() => {}
797 next_message = next_message => {
798 let (permit, message) = next_message;
799 if let Some(message) = message {
800 let type_name = message.payload_type_name();
801 // note: we copy all the fields from the parent span so we can query them in the logs.
802 // (https://github.com/tokio-rs/tracing/issues/2670).
803 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
804 user_id=field::Empty,
805 login=field::Empty,
806 impersonator=field::Empty,
807 );
808 principal.update_span(&span);
809 let span_enter = span.enter();
810 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
811 let is_background = message.is_background();
812 let handle_message = (handler)(message, session.clone());
813 drop(span_enter);
814
815 let handle_message = async move {
816 handle_message.await;
817 drop(permit);
818 }.instrument(span);
819 if is_background {
820 executor.spawn_detached(handle_message);
821 } else {
822 foreground_message_handlers.push(handle_message);
823 }
824 } else {
825 tracing::error!("no message handler");
826 }
827 } else {
828 tracing::info!("connection closed");
829 break;
830 }
831 }
832 }
833 }
834
835 drop(foreground_message_handlers);
836 tracing::info!("signing out");
837 if let Err(error) = connection_lost(session, teardown, executor).await {
838 tracing::error!(?error, "error signing out");
839 }
840
841 }.instrument(span)
842 }
843
844 async fn send_initial_client_update(
845 &self,
846 connection_id: ConnectionId,
847 principal: &Principal,
848 zed_version: ZedVersion,
849 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
850 session: &Session,
851 ) -> Result<()> {
852 self.peer.send(
853 connection_id,
854 proto::Hello {
855 peer_id: Some(connection_id.into()),
856 },
857 )?;
858 tracing::info!("sent hello message");
859 if let Some(send_connection_id) = send_connection_id.take() {
860 let _ = send_connection_id.send(connection_id);
861 }
862
863 match principal {
864 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
865 if !user.connected_once {
866 self.peer.send(connection_id, proto::ShowContacts {})?;
867 self.app_state
868 .db
869 .set_user_connected_once(user.id, true)
870 .await?;
871 }
872
873 update_user_plan(user.id, session).await?;
874
875 let contacts = self.app_state.db.get_contacts(user.id).await?;
876
877 {
878 let mut pool = self.connection_pool.lock();
879 pool.add_connection(connection_id, user.id, user.admin, zed_version);
880 self.peer.send(
881 connection_id,
882 build_initial_contacts_update(contacts, &pool),
883 )?;
884 }
885
886 if should_auto_subscribe_to_channels(zed_version) {
887 subscribe_user_to_channels(user.id, session).await?;
888 }
889
890 if let Some(incoming_call) =
891 self.app_state.db.incoming_call_for_user(user.id).await?
892 {
893 self.peer.send(connection_id, incoming_call)?;
894 }
895
896 update_user_contacts(user.id, session).await?;
897 }
898 }
899
900 Ok(())
901 }
902
903 pub async fn invite_code_redeemed(
904 self: &Arc<Self>,
905 inviter_id: UserId,
906 invitee_id: UserId,
907 ) -> Result<()> {
908 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
909 if let Some(code) = &user.invite_code {
910 let pool = self.connection_pool.lock();
911 let invitee_contact = contact_for_user(invitee_id, false, &pool);
912 for connection_id in pool.user_connection_ids(inviter_id) {
913 self.peer.send(
914 connection_id,
915 proto::UpdateContacts {
916 contacts: vec![invitee_contact.clone()],
917 ..Default::default()
918 },
919 )?;
920 self.peer.send(
921 connection_id,
922 proto::UpdateInviteInfo {
923 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
924 count: user.invite_count as u32,
925 },
926 )?;
927 }
928 }
929 }
930 Ok(())
931 }
932
933 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
934 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
935 if let Some(invite_code) = &user.invite_code {
936 let pool = self.connection_pool.lock();
937 for connection_id in pool.user_connection_ids(user_id) {
938 self.peer.send(
939 connection_id,
940 proto::UpdateInviteInfo {
941 url: format!(
942 "{}{}",
943 self.app_state.config.invite_link_prefix, invite_code
944 ),
945 count: user.invite_count as u32,
946 },
947 )?;
948 }
949 }
950 }
951 Ok(())
952 }
953
954 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
955 let pool = self.connection_pool.lock();
956 for connection_id in pool.user_connection_ids(user_id) {
957 self.peer
958 .send(connection_id, proto::RefreshLlmToken {})
959 .trace_err();
960 }
961 }
962
963 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
964 ServerSnapshot {
965 connection_pool: ConnectionPoolGuard {
966 guard: self.connection_pool.lock(),
967 _not_send: PhantomData,
968 },
969 peer: &self.peer,
970 }
971 }
972}
973
974impl<'a> Deref for ConnectionPoolGuard<'a> {
975 type Target = ConnectionPool;
976
977 fn deref(&self) -> &Self::Target {
978 &self.guard
979 }
980}
981
982impl<'a> DerefMut for ConnectionPoolGuard<'a> {
983 fn deref_mut(&mut self) -> &mut Self::Target {
984 &mut self.guard
985 }
986}
987
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` every time a guard is dropped.
    /// A mutation that leaves the pool inconsistent is therefore caught when
    /// the lock is released. In non-test builds the body is compiled out.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
994
995fn broadcast<F>(
996 sender_id: Option<ConnectionId>,
997 receiver_ids: impl IntoIterator<Item = ConnectionId>,
998 mut f: F,
999) where
1000 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1001{
1002 for receiver_id in receiver_ids {
1003 if Some(receiver_id) != sender_id {
1004 if let Err(error) = f(receiver_id) {
1005 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1006 }
1007 }
1008 }
1009}
1010
1011pub struct ProtocolVersion(u32);
1012
1013impl Header for ProtocolVersion {
1014 fn name() -> &'static HeaderName {
1015 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1016 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1017 }
1018
1019 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1020 where
1021 Self: Sized,
1022 I: Iterator<Item = &'i axum::http::HeaderValue>,
1023 {
1024 let version = values
1025 .next()
1026 .ok_or_else(axum::headers::Error::invalid)?
1027 .to_str()
1028 .map_err(|_| axum::headers::Error::invalid())?
1029 .parse()
1030 .map_err(|_| axum::headers::Error::invalid())?;
1031 Ok(Self(version))
1032 }
1033
1034 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1035 values.extend([self.0.to_string().parse().unwrap()]);
1036 }
1037}
1038
1039pub struct AppVersionHeader(SemanticVersion);
1040impl Header for AppVersionHeader {
1041 fn name() -> &'static HeaderName {
1042 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1043 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1044 }
1045
1046 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1047 where
1048 Self: Sized,
1049 I: Iterator<Item = &'i axum::http::HeaderValue>,
1050 {
1051 let version = values
1052 .next()
1053 .ok_or_else(axum::headers::Error::invalid)?
1054 .to_str()
1055 .map_err(|_| axum::headers::Error::invalid())?
1056 .parse()
1057 .map_err(|_| axum::headers::Error::invalid())?;
1058 Ok(Self(version))
1059 }
1060
1061 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1062 values.extend([self.0.to_string().parse().unwrap()]);
1063 }
1064}
1065
/// Builds the router for the collab RPC endpoints.
///
/// - `/rpc` is the WebSocket upgrade endpoint.
/// - `/metrics` serves Prometheus text.
///
/// In axum, `Router::layer` wraps only the routes added before it. So the
/// `ServiceBuilder` layer (the `AppState` extension plus the
/// `auth::validate_header` middleware) applies to `/rpc` but not to
/// `/metrics`. The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Must stay between the two `.route` calls: it only wraps `/rpc`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1077
/// Upgrades an `/rpc` request to a WebSocket and hands the resulting
/// connection to `Server::handle_connection`.
///
/// Before upgrading, the request is answered with `426 Upgrade Required` when:
/// - the `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`;
/// - there is no `x-zed-app-version` header;
/// - that app version reports `!can_collaborate()`.
///
/// The optional Cloudflare country header and system-id header are forwarded
/// to `handle_connection` as strings.
#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();
    ws.on_upgrade(move |socket| {
        // Adapt axum's WebSocket stream/sink to the tungstenite message type
        // that `rpc::Connection` is built on: incoming messages are converted
        // with `to_tungstenite_message`, outgoing ones with `to_axum_message`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                )
                .await;
        }
    })
}
1136
/// Serves Prometheus metrics in the text exposition format.
///
/// Two gauges are refreshed before gathering the default registry:
/// - `connections`: pool connections whose `admin` flag is false;
/// - `shared_projects`: the database's `project_count_excluding_admins`.
///
/// Each gauge is registered lazily on the first request. The `unwrap` panics
/// if registration fails.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // The pool lock is a temporary of this statement, so it is released
    // before the database query below is awaited.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{}", err))?;
    Ok(encoded_metrics)
}
1169
/// Cleans up after a client connection drops.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - calls the database's `connection_lost` (errors are only logged).
///
/// It then waits for whichever comes first, `RECONNECT_TIMEOUT` or a change on
/// `teardown`. `select_biased!` polls the timeout arm first when both are
/// ready. If `teardown` changes first, no further cleanup happens here.
///
/// Otherwise, once the timeout fires:
/// - the session leaves its room and channel buffers (errors are logged);
/// - if the user has no remaining connections in the pool,
///   `decline_call(None, ..)` is invoked and any returned room is broadcast;
/// - the user's contacts are refreshed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // The pool guard is a temporary of the `if` condition and is dropped
            // before the body runs.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1216
1217/// Acknowledges a ping from a client, used to keep the connection alive.
1218async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1219 response.send(proto::Ack {})?;
1220 Ok(())
1221}
1222
1223/// Creates a new room for calling (outside of channels)
1224async fn create_room(
1225 _request: proto::CreateRoom,
1226 response: Response<proto::CreateRoom>,
1227 session: Session,
1228) -> Result<()> {
1229 let livekit_room = nanoid::nanoid!(30);
1230
1231 let live_kit_connection_info = util::maybe!(async {
1232 let live_kit = session.app_state.livekit_client.as_ref();
1233 let live_kit = live_kit?;
1234 let user_id = session.user_id().to_string();
1235
1236 let token = live_kit
1237 .room_token(&livekit_room, &user_id.to_string())
1238 .trace_err()?;
1239
1240 Some(proto::LiveKitConnectionInfo {
1241 server_url: live_kit.url().into(),
1242 token,
1243 can_publish: true,
1244 })
1245 })
1246 .await;
1247
1248 let room = session
1249 .db()
1250 .await
1251 .create_room(session.user_id(), session.connection_id, &livekit_room)
1252 .await?;
1253
1254 response.send(proto::CreateRoomResponse {
1255 room: Some(room.clone()),
1256 live_kit_connection_info,
1257 })?;
1258
1259 update_user_contacts(session.user_id(), &session).await?;
1260 Ok(())
1261}
1262
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, this delegates to `join_channel_internal`.
/// Otherwise it:
/// 1. joins the room in the database and broadcasts the room update;
/// 2. sends `CallCanceled` for this room to all of the user's connections
///    (presumably so the user's other clients stop ringing; confirm);
/// 3. mints a LiveKit token when a LiveKit client is configured (failures are
///    logged and yield no connection info);
/// 4. responds, then refreshes the user's contacts.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so the value returned by `join_room` (unwrapped via `into_inner`)
    // is dropped before the connection pool is locked below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The pool guard is a temporary of the `for` expression, held only for the loop.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1329
/// Reconnects a client to a room after connection errors.
///
/// The database returns the room together with two project lists:
/// "reshared" projects (presumably the ones this user hosts; confirm) and
/// "rejoined" projects.
///
/// While the database result is alive:
/// 1. the client gets a `RejoinRoomResponse` and the room update is broadcast;
/// 2. every collaborator of a reshared project is told the peer id changed
///    from `old_connection_id` to this connection;
/// 3. the other collaborators of each reshared project get its worktrees via
///    `UpdateProject`;
/// 4. the rejoined projects are re-synced through `notify_rejoined_projects`.
///
/// Afterwards the room's channel (if any) is updated and the user's contacts
/// are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the database result is dropped (after `into_inner`) before the
    // connection pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to every collaborator
            // except this connection; per-recipient failures are logged by `broadcast`.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1421
1422fn notify_rejoined_projects(
1423 rejoined_projects: &mut Vec<RejoinedProject>,
1424 session: &Session,
1425) -> Result<()> {
1426 for project in rejoined_projects.iter() {
1427 for collaborator in &project.collaborators {
1428 session
1429 .peer
1430 .send(
1431 collaborator.connection_id,
1432 proto::UpdateProjectCollaborator {
1433 project_id: project.id.to_proto(),
1434 old_peer_id: Some(project.old_connection_id.into()),
1435 new_peer_id: Some(session.connection_id.into()),
1436 },
1437 )
1438 .trace_err();
1439 }
1440 }
1441
1442 for project in rejoined_projects {
1443 for worktree in mem::take(&mut project.worktrees) {
1444 // Stream this worktree's entries.
1445 let message = proto::UpdateWorktree {
1446 project_id: project.id.to_proto(),
1447 worktree_id: worktree.id,
1448 abs_path: worktree.abs_path.clone(),
1449 root_name: worktree.root_name,
1450 updated_entries: worktree.updated_entries,
1451 removed_entries: worktree.removed_entries,
1452 scan_id: worktree.scan_id,
1453 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1454 updated_repositories: worktree.updated_repositories,
1455 removed_repositories: worktree.removed_repositories,
1456 };
1457 for update in proto::split_worktree_update(message) {
1458 session.peer.send(session.connection_id, update.clone())?;
1459 }
1460
1461 // Stream this worktree's diagnostics.
1462 for summary in worktree.diagnostic_summaries {
1463 session.peer.send(
1464 session.connection_id,
1465 proto::UpdateDiagnosticSummary {
1466 project_id: project.id.to_proto(),
1467 worktree_id: worktree.id,
1468 summary: Some(summary),
1469 },
1470 )?;
1471 }
1472
1473 for settings_file in worktree.settings_files {
1474 session.peer.send(
1475 session.connection_id,
1476 proto::UpdateWorktreeSettings {
1477 project_id: project.id.to_proto(),
1478 worktree_id: worktree.id,
1479 path: settings_file.path,
1480 content: Some(settings_file.content),
1481 kind: Some(settings_file.kind.to_proto().into()),
1482 },
1483 )?;
1484 }
1485 }
1486
1487 for language_server in &project.language_servers {
1488 session.peer.send(
1489 session.connection_id,
1490 proto::UpdateLanguageServer {
1491 project_id: project.id.to_proto(),
1492 language_server_id: language_server.id,
1493 variant: Some(
1494 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1495 proto::LspDiskBasedDiagnosticsUpdated {},
1496 ),
1497 ),
1498 },
1499 )?;
1500 }
1501 }
1502 Ok(())
1503}
1504
1505/// leave room disconnects from the room.
1506async fn leave_room(
1507 _: proto::LeaveRoom,
1508 response: Response<proto::LeaveRoom>,
1509 session: Session,
1510) -> Result<()> {
1511 leave_room_for_session(&session, session.connection_id).await?;
1512 response.send(proto::Ack {})?;
1513 Ok(())
1514}
1515
1516/// Updates the permissions of someone else in the room.
1517async fn set_room_participant_role(
1518 request: proto::SetRoomParticipantRole,
1519 response: Response<proto::SetRoomParticipantRole>,
1520 session: Session,
1521) -> Result<()> {
1522 let user_id = UserId::from_proto(request.user_id);
1523 let role = ChannelRole::from(request.role());
1524
1525 let (livekit_room, can_publish) = {
1526 let room = session
1527 .db()
1528 .await
1529 .set_room_participant_role(
1530 session.user_id(),
1531 RoomId::from_proto(request.room_id),
1532 user_id,
1533 role,
1534 )
1535 .await?;
1536
1537 let livekit_room = room.livekit_room.clone();
1538 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1539 room_updated(&room, &session.peer);
1540 (livekit_room, can_publish)
1541 };
1542
1543 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1544 live_kit
1545 .update_participant(
1546 livekit_room.clone(),
1547 request.user_id.to_string(),
1548 livekit_api::proto::ParticipantPermission {
1549 can_subscribe: true,
1550 can_publish,
1551 can_publish_data: can_publish,
1552 hidden: false,
1553 recorder: false,
1554 },
1555 )
1556 .await
1557 .trace_err();
1558 }
1559
1560 response.send(proto::Ack {})?;
1561 Ok(())
1562}
1563
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. Steps:
/// 1. Record the call in the database and broadcast the room update.
/// 2. Refresh the callee's contacts.
/// 3. Send the `IncomingCall` request to every connection of the callee
///    concurrently. The first connection that answers `Ok` causes an `Ack`,
///    and the remaining requests are dropped.
///
/// If no connection answers successfully:
/// - the call is marked failed in the database and the room is broadcast;
/// - the callee's contacts are refreshed;
/// - "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result is dropped before contacts are refreshed;
    // the incoming-call message is moved out with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once; the pool guard is released
    // at the end of this statement, before any response is awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1632
/// Cancel an outgoing call.
///
/// Steps:
/// 1. Cancel the call in the database, as this connection, for
///    `called_user_id`, and broadcast the room update.
/// 2. Send `CallCanceled` to every connection of the callee (failures are
///    logged).
/// 3. Acknowledge, then refresh the callee's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // The pool guard is a temporary of the `for` expression, held only for the
    // loop and released before `update_user_contacts` is awaited.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1670
/// Decline an incoming call.
///
/// Declines the call for this user in the database and broadcasts the room
/// update; an error is returned if the database reports nothing to decline.
/// Then `CallCanceled` is sent to all of this user's own connections (failures
/// are logged) and the user's contacts are refreshed.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    // The pool guard is a temporary of the `for` expression, held only for the loop.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1702
/// Updates other participants in the room with your current location.
///
/// A request without a `location` is rejected with "invalid location".
/// Otherwise the location is stored for this connection, the room update is
/// broadcast and the request is acknowledged.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1723
/// Share a project into the room.
///
/// Registers the project (with its worktrees and `is_ssh_project` flag) for
/// this connection in the given room. The client gets the new project id and
/// the room update is broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The database result is borrowed for the rest of the function, so it is
    // still alive while `room_updated` runs.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1747
1748/// Unshare a project from the room.
1749async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1750 let project_id = ProjectId::from_proto(message.project_id);
1751 unshare_project_internal(project_id, session.connection_id, &session).await
1752}
1753
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The database returns three things:
/// - whether the project should be deleted;
/// - the affected room, if any;
/// - the guest connections.
///
/// Every guest except `connection_id` is sent `UnshareProject` (failures are
/// logged), and the room update is broadcast. If the database asked for
/// deletion, the project is deleted after that result has been dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before `delete_project` takes the
    // database again below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1791
/// Join someone else's shared project.
///
/// Registers this connection as a guest in the database and then streams the
/// project state to it via `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The borrowed result of `join_project` outlives `db`, which is dropped
    // explicitly before the (synchronous) streaming work below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1810
/// Abstraction over "where to send a `JoinProjectResponse`", so that
/// `join_project_internal` can be driven by any implementor of this trait
/// rather than only by a `Response<proto::JoinProject>`.
trait JoinProjectInternalResponse {
    /// Sends `result`, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Responds directly to a `JoinProject` RPC request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so this calls the inherent `Response::send`, not the
        // trait method being defined here.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1819
1820fn join_project_internal(
1821 response: impl JoinProjectInternalResponse,
1822 session: Session,
1823 project: &mut Project,
1824 replica_id: &ReplicaId,
1825) -> Result<()> {
1826 let collaborators = project
1827 .collaborators
1828 .iter()
1829 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1830 .map(|collaborator| collaborator.to_proto())
1831 .collect::<Vec<_>>();
1832 let project_id = project.id;
1833 let guest_user_id = session.user_id();
1834
1835 let worktrees = project
1836 .worktrees
1837 .iter()
1838 .map(|(id, worktree)| proto::WorktreeMetadata {
1839 id: *id,
1840 root_name: worktree.root_name.clone(),
1841 visible: worktree.visible,
1842 abs_path: worktree.abs_path.clone(),
1843 })
1844 .collect::<Vec<_>>();
1845
1846 let add_project_collaborator = proto::AddProjectCollaborator {
1847 project_id: project_id.to_proto(),
1848 collaborator: Some(proto::Collaborator {
1849 peer_id: Some(session.connection_id.into()),
1850 replica_id: replica_id.0 as u32,
1851 user_id: guest_user_id.to_proto(),
1852 is_host: false,
1853 }),
1854 };
1855
1856 for collaborator in &collaborators {
1857 session
1858 .peer
1859 .send(
1860 collaborator.peer_id.unwrap().into(),
1861 add_project_collaborator.clone(),
1862 )
1863 .trace_err();
1864 }
1865
1866 // First, we send the metadata associated with each worktree.
1867 response.send(proto::JoinProjectResponse {
1868 project_id: project.id.0 as u64,
1869 worktrees: worktrees.clone(),
1870 replica_id: replica_id.0 as u32,
1871 collaborators: collaborators.clone(),
1872 language_servers: project.language_servers.clone(),
1873 role: project.role.into(),
1874 })?;
1875
1876 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1877 // Stream this worktree's entries.
1878 let message = proto::UpdateWorktree {
1879 project_id: project_id.to_proto(),
1880 worktree_id,
1881 abs_path: worktree.abs_path.clone(),
1882 root_name: worktree.root_name,
1883 updated_entries: worktree.entries,
1884 removed_entries: Default::default(),
1885 scan_id: worktree.scan_id,
1886 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1887 updated_repositories: worktree.repository_entries.into_values().collect(),
1888 removed_repositories: Default::default(),
1889 };
1890 for update in proto::split_worktree_update(message) {
1891 session.peer.send(session.connection_id, update.clone())?;
1892 }
1893
1894 // Stream this worktree's diagnostics.
1895 for summary in worktree.diagnostic_summaries {
1896 session.peer.send(
1897 session.connection_id,
1898 proto::UpdateDiagnosticSummary {
1899 project_id: project_id.to_proto(),
1900 worktree_id: worktree.id,
1901 summary: Some(summary),
1902 },
1903 )?;
1904 }
1905
1906 for settings_file in worktree.settings_files {
1907 session.peer.send(
1908 session.connection_id,
1909 proto::UpdateWorktreeSettings {
1910 project_id: project_id.to_proto(),
1911 worktree_id: worktree.id,
1912 path: settings_file.path,
1913 content: Some(settings_file.content),
1914 kind: Some(settings_file.kind.to_proto() as i32),
1915 },
1916 )?;
1917 }
1918 }
1919
1920 for language_server in &project.language_servers {
1921 session.peer.send(
1922 session.connection_id,
1923 proto::UpdateLanguageServer {
1924 project_id: project_id.to_proto(),
1925 language_server_id: language_server.id,
1926 variant: Some(
1927 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1928 proto::LspDiskBasedDiagnosticsUpdated {},
1929 ),
1930 ),
1931 },
1932 )?;
1933 }
1934
1935 Ok(())
1936}
1937
/// Leave someone else's shared project.
///
/// Removes this connection from the project in the database. The departure is
/// then handed to `project_left`, and the room update is broadcast when the
/// database reports an affected room.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1957
/// Updates other participants with changes to the project.
///
/// Applies the new worktree list through the database. The request is then
/// forwarded verbatim to every returned guest connection except the sender
/// (failures are logged), the room update is broadcast when a room is
/// returned, and the request is acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
1986
1987/// Updates other participants with changes to the worktree
1988async fn update_worktree(
1989 request: proto::UpdateWorktree,
1990 response: Response<proto::UpdateWorktree>,
1991 session: Session,
1992) -> Result<()> {
1993 let guest_connection_ids = session
1994 .db()
1995 .await
1996 .update_worktree(&request, session.connection_id)
1997 .await?;
1998
1999 broadcast(
2000 Some(session.connection_id),
2001 guest_connection_ids.iter().copied(),
2002 |connection_id| {
2003 session
2004 .peer
2005 .forward_send(session.connection_id, connection_id, request.clone())
2006 },
2007 );
2008 response.send(proto::Ack {})?;
2009 Ok(())
2010}
2011
2012/// Updates other participants with changes to the diagnostics
2013async fn update_diagnostic_summary(
2014 message: proto::UpdateDiagnosticSummary,
2015 session: Session,
2016) -> Result<()> {
2017 let guest_connection_ids = session
2018 .db()
2019 .await
2020 .update_diagnostic_summary(&message, session.connection_id)
2021 .await?;
2022
2023 broadcast(
2024 Some(session.connection_id),
2025 guest_connection_ids.iter().copied(),
2026 |connection_id| {
2027 session
2028 .peer
2029 .forward_send(session.connection_id, connection_id, message.clone())
2030 },
2031 );
2032
2033 Ok(())
2034}
2035
2036/// Updates other participants with changes to the worktree settings
2037async fn update_worktree_settings(
2038 message: proto::UpdateWorktreeSettings,
2039 session: Session,
2040) -> Result<()> {
2041 let guest_connection_ids = session
2042 .db()
2043 .await
2044 .update_worktree_settings(&message, session.connection_id)
2045 .await?;
2046
2047 broadcast(
2048 Some(session.connection_id),
2049 guest_connection_ids.iter().copied(),
2050 |connection_id| {
2051 session
2052 .peer
2053 .forward_send(session.connection_id, connection_id, message.clone())
2054 },
2055 );
2056
2057 Ok(())
2058}
2059
2060/// Notify other participants that a language server has started.
2061async fn start_language_server(
2062 request: proto::StartLanguageServer,
2063 session: Session,
2064) -> Result<()> {
2065 let guest_connection_ids = session
2066 .db()
2067 .await
2068 .start_language_server(&request, session.connection_id)
2069 .await?;
2070
2071 broadcast(
2072 Some(session.connection_id),
2073 guest_connection_ids.iter().copied(),
2074 |connection_id| {
2075 session
2076 .peer
2077 .forward_send(session.connection_id, connection_id, request.clone())
2078 },
2079 );
2080 Ok(())
2081}
2082
2083/// Notify other participants that a language server has changed.
2084async fn update_language_server(
2085 request: proto::UpdateLanguageServer,
2086 session: Session,
2087) -> Result<()> {
2088 let project_id = ProjectId::from_proto(request.project_id);
2089 let project_connection_ids = session
2090 .db()
2091 .await
2092 .project_connection_ids(project_id, session.connection_id, true)
2093 .await?;
2094 broadcast(
2095 Some(session.connection_id),
2096 project_connection_ids.iter().copied(),
2097 |connection_id| {
2098 session
2099 .peer
2100 .forward_send(session.connection_id, connection_id, request.clone())
2101 },
2102 );
2103 Ok(())
2104}
2105
2106/// forward a project request to the host. These requests should be read only
2107/// as guests are allowed to send them.
2108async fn forward_read_only_project_request<T>(
2109 request: T,
2110 response: Response<T>,
2111 session: Session,
2112) -> Result<()>
2113where
2114 T: EntityMessage + RequestMessage,
2115{
2116 let project_id = ProjectId::from_proto(request.remote_entity_id());
2117 let host_connection_id = session
2118 .db()
2119 .await
2120 .host_for_read_only_project_request(project_id, session.connection_id)
2121 .await?;
2122 let payload = session
2123 .peer
2124 .forward_request(session.connection_id, host_connection_id, request)
2125 .await?;
2126 response.send(payload)?;
2127 Ok(())
2128}
2129
2130async fn forward_find_search_candidates_request(
2131 request: proto::FindSearchCandidates,
2132 response: Response<proto::FindSearchCandidates>,
2133 session: Session,
2134) -> Result<()> {
2135 let project_id = ProjectId::from_proto(request.remote_entity_id());
2136 let host_connection_id = session
2137 .db()
2138 .await
2139 .host_for_read_only_project_request(project_id, session.connection_id)
2140 .await?;
2141 let payload = session
2142 .peer
2143 .forward_request(session.connection_id, host_connection_id, request)
2144 .await?;
2145 response.send(payload)?;
2146 Ok(())
2147}
2148
2149/// forward a project request to the host. These requests are disallowed
2150/// for guests.
2151async fn forward_mutating_project_request<T>(
2152 request: T,
2153 response: Response<T>,
2154 session: Session,
2155) -> Result<()>
2156where
2157 T: EntityMessage + RequestMessage,
2158{
2159 let project_id = ProjectId::from_proto(request.remote_entity_id());
2160
2161 let host_connection_id = session
2162 .db()
2163 .await
2164 .host_for_mutating_project_request(project_id, session.connection_id)
2165 .await?;
2166 let payload = session
2167 .peer
2168 .forward_request(session.connection_id, host_connection_id, request)
2169 .await?;
2170 response.send(payload)?;
2171 Ok(())
2172}
2173
2174/// Notify other participants that a new buffer has been created
2175async fn create_buffer_for_peer(
2176 request: proto::CreateBufferForPeer,
2177 session: Session,
2178) -> Result<()> {
2179 session
2180 .db()
2181 .await
2182 .check_user_is_project_host(
2183 ProjectId::from_proto(request.project_id),
2184 session.connection_id,
2185 )
2186 .await?;
2187 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2188 session
2189 .peer
2190 .forward_send(session.connection_id, peer_id.into(), request)?;
2191 Ok(())
2192}
2193
2194/// Notify other participants that a buffer has been updated. This is
2195/// allowed for guests as long as the update is limited to selections.
2196async fn update_buffer(
2197 request: proto::UpdateBuffer,
2198 response: Response<proto::UpdateBuffer>,
2199 session: Session,
2200) -> Result<()> {
2201 let project_id = ProjectId::from_proto(request.project_id);
2202 let mut capability = Capability::ReadOnly;
2203
2204 for op in request.operations.iter() {
2205 match op.variant {
2206 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2207 Some(_) => capability = Capability::ReadWrite,
2208 }
2209 }
2210
2211 let host = {
2212 let guard = session
2213 .db()
2214 .await
2215 .connections_for_buffer_update(project_id, session.connection_id, capability)
2216 .await?;
2217
2218 let (host, guests) = &*guard;
2219
2220 broadcast(
2221 Some(session.connection_id),
2222 guests.clone(),
2223 |connection_id| {
2224 session
2225 .peer
2226 .forward_send(session.connection_id, connection_id, request.clone())
2227 },
2228 );
2229
2230 *host
2231 };
2232
2233 if host != session.connection_id {
2234 session
2235 .peer
2236 .forward_request(session.connection_id, host, request.clone())
2237 .await?;
2238 }
2239
2240 response.send(proto::Ack {})?;
2241 Ok(())
2242}
2243
2244async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2245 let project_id = ProjectId::from_proto(message.project_id);
2246
2247 let operation = message.operation.as_ref().context("invalid operation")?;
2248 let capability = match operation.variant.as_ref() {
2249 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2250 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2251 match buffer_op.variant {
2252 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2253 Capability::ReadOnly
2254 }
2255 _ => Capability::ReadWrite,
2256 }
2257 } else {
2258 Capability::ReadWrite
2259 }
2260 }
2261 Some(_) => Capability::ReadWrite,
2262 None => Capability::ReadOnly,
2263 };
2264
2265 let guard = session
2266 .db()
2267 .await
2268 .connections_for_buffer_update(project_id, session.connection_id, capability)
2269 .await?;
2270
2271 let (host, guests) = &*guard;
2272
2273 broadcast(
2274 Some(session.connection_id),
2275 guests.iter().chain([host]).copied(),
2276 |connection_id| {
2277 session
2278 .peer
2279 .forward_send(session.connection_id, connection_id, message.clone())
2280 },
2281 );
2282
2283 Ok(())
2284}
2285
2286/// Notify other participants that a project has been updated.
2287async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2288 request: T,
2289 session: Session,
2290) -> Result<()> {
2291 let project_id = ProjectId::from_proto(request.remote_entity_id());
2292 let project_connection_ids = session
2293 .db()
2294 .await
2295 .project_connection_ids(project_id, session.connection_id, false)
2296 .await?;
2297
2298 broadcast(
2299 Some(session.connection_id),
2300 project_connection_ids.iter().copied(),
2301 |connection_id| {
2302 session
2303 .peer
2304 .forward_send(session.connection_id, connection_id, request.clone())
2305 },
2306 );
2307 Ok(())
2308}
2309
2310/// Start following another user in a call.
2311async fn follow(
2312 request: proto::Follow,
2313 response: Response<proto::Follow>,
2314 session: Session,
2315) -> Result<()> {
2316 let room_id = RoomId::from_proto(request.room_id);
2317 let project_id = request.project_id.map(ProjectId::from_proto);
2318 let leader_id = request
2319 .leader_id
2320 .ok_or_else(|| anyhow!("invalid leader id"))?
2321 .into();
2322 let follower_id = session.connection_id;
2323
2324 session
2325 .db()
2326 .await
2327 .check_room_participants(room_id, leader_id, session.connection_id)
2328 .await?;
2329
2330 let response_payload = session
2331 .peer
2332 .forward_request(session.connection_id, leader_id, request)
2333 .await?;
2334 response.send(response_payload)?;
2335
2336 if let Some(project_id) = project_id {
2337 let room = session
2338 .db()
2339 .await
2340 .follow(room_id, project_id, leader_id, follower_id)
2341 .await?;
2342 room_updated(&room, &session.peer);
2343 }
2344
2345 Ok(())
2346}
2347
2348/// Stop following another user in a call.
2349async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2350 let room_id = RoomId::from_proto(request.room_id);
2351 let project_id = request.project_id.map(ProjectId::from_proto);
2352 let leader_id = request
2353 .leader_id
2354 .ok_or_else(|| anyhow!("invalid leader id"))?
2355 .into();
2356 let follower_id = session.connection_id;
2357
2358 session
2359 .db()
2360 .await
2361 .check_room_participants(room_id, leader_id, session.connection_id)
2362 .await?;
2363
2364 session
2365 .peer
2366 .forward_send(session.connection_id, leader_id, request)?;
2367
2368 if let Some(project_id) = project_id {
2369 let room = session
2370 .db()
2371 .await
2372 .unfollow(room_id, project_id, leader_id, follower_id)
2373 .await?;
2374 room_updated(&room, &session.peer);
2375 }
2376
2377 Ok(())
2378}
2379
2380/// Notify everyone following you of your current location.
2381async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2382 let room_id = RoomId::from_proto(request.room_id);
2383 let database = session.db.lock().await;
2384
2385 let connection_ids = if let Some(project_id) = request.project_id {
2386 let project_id = ProjectId::from_proto(project_id);
2387 database
2388 .project_connection_ids(project_id, session.connection_id, true)
2389 .await?
2390 } else {
2391 database
2392 .room_connection_ids(room_id, session.connection_id)
2393 .await?
2394 };
2395
2396 // For now, don't send view update messages back to that view's current leader.
2397 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2398 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2399 _ => None,
2400 });
2401
2402 for connection_id in connection_ids.iter().cloned() {
2403 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2404 session
2405 .peer
2406 .forward_send(session.connection_id, connection_id, request.clone())?;
2407 }
2408 }
2409 Ok(())
2410}
2411
2412/// Get public data about users.
2413async fn get_users(
2414 request: proto::GetUsers,
2415 response: Response<proto::GetUsers>,
2416 session: Session,
2417) -> Result<()> {
2418 let user_ids = request
2419 .user_ids
2420 .into_iter()
2421 .map(UserId::from_proto)
2422 .collect();
2423 let users = session
2424 .db()
2425 .await
2426 .get_users_by_ids(user_ids)
2427 .await?
2428 .into_iter()
2429 .map(|user| proto::User {
2430 id: user.id.to_proto(),
2431 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2432 github_login: user.github_login,
2433 email: user.email_address,
2434 name: user.name,
2435 })
2436 .collect();
2437 response.send(proto::UsersResponse { users })?;
2438 Ok(())
2439}
2440
2441/// Search for users (to invite) buy Github login
2442async fn fuzzy_search_users(
2443 request: proto::FuzzySearchUsers,
2444 response: Response<proto::FuzzySearchUsers>,
2445 session: Session,
2446) -> Result<()> {
2447 let query = request.query;
2448 let users = match query.len() {
2449 0 => vec![],
2450 1 | 2 => session
2451 .db()
2452 .await
2453 .get_user_by_github_login(&query)
2454 .await?
2455 .into_iter()
2456 .collect(),
2457 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2458 };
2459 let users = users
2460 .into_iter()
2461 .filter(|user| user.id != session.user_id())
2462 .map(|user| proto::User {
2463 id: user.id.to_proto(),
2464 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2465 github_login: user.github_login,
2466 name: user.name,
2467 email: user.email_address,
2468 })
2469 .collect();
2470 response.send(proto::UsersResponse { users })?;
2471 Ok(())
2472}
2473
2474/// Send a contact request to another user.
2475async fn request_contact(
2476 request: proto::RequestContact,
2477 response: Response<proto::RequestContact>,
2478 session: Session,
2479) -> Result<()> {
2480 let requester_id = session.user_id();
2481 let responder_id = UserId::from_proto(request.responder_id);
2482 if requester_id == responder_id {
2483 return Err(anyhow!("cannot add yourself as a contact"))?;
2484 }
2485
2486 let notifications = session
2487 .db()
2488 .await
2489 .send_contact_request(requester_id, responder_id)
2490 .await?;
2491
2492 // Update outgoing contact requests of requester
2493 let mut update = proto::UpdateContacts::default();
2494 update.outgoing_requests.push(responder_id.to_proto());
2495 for connection_id in session
2496 .connection_pool()
2497 .await
2498 .user_connection_ids(requester_id)
2499 {
2500 session.peer.send(connection_id, update.clone())?;
2501 }
2502
2503 // Update incoming contact requests of responder
2504 let mut update = proto::UpdateContacts::default();
2505 update
2506 .incoming_requests
2507 .push(proto::IncomingContactRequest {
2508 requester_id: requester_id.to_proto(),
2509 });
2510 let connection_pool = session.connection_pool().await;
2511 for connection_id in connection_pool.user_connection_ids(responder_id) {
2512 session.peer.send(connection_id, update.clone())?;
2513 }
2514
2515 send_notifications(&connection_pool, &session.peer, notifications);
2516
2517 response.send(proto::Ack {})?;
2518 Ok(())
2519}
2520
2521/// Accept or decline a contact request
2522async fn respond_to_contact_request(
2523 request: proto::RespondToContactRequest,
2524 response: Response<proto::RespondToContactRequest>,
2525 session: Session,
2526) -> Result<()> {
2527 let responder_id = session.user_id();
2528 let requester_id = UserId::from_proto(request.requester_id);
2529 let db = session.db().await;
2530 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2531 db.dismiss_contact_notification(responder_id, requester_id)
2532 .await?;
2533 } else {
2534 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2535
2536 let notifications = db
2537 .respond_to_contact_request(responder_id, requester_id, accept)
2538 .await?;
2539 let requester_busy = db.is_user_busy(requester_id).await?;
2540 let responder_busy = db.is_user_busy(responder_id).await?;
2541
2542 let pool = session.connection_pool().await;
2543 // Update responder with new contact
2544 let mut update = proto::UpdateContacts::default();
2545 if accept {
2546 update
2547 .contacts
2548 .push(contact_for_user(requester_id, requester_busy, &pool));
2549 }
2550 update
2551 .remove_incoming_requests
2552 .push(requester_id.to_proto());
2553 for connection_id in pool.user_connection_ids(responder_id) {
2554 session.peer.send(connection_id, update.clone())?;
2555 }
2556
2557 // Update requester with new contact
2558 let mut update = proto::UpdateContacts::default();
2559 if accept {
2560 update
2561 .contacts
2562 .push(contact_for_user(responder_id, responder_busy, &pool));
2563 }
2564 update
2565 .remove_outgoing_requests
2566 .push(responder_id.to_proto());
2567
2568 for connection_id in pool.user_connection_ids(requester_id) {
2569 session.peer.send(connection_id, update.clone())?;
2570 }
2571
2572 send_notifications(&pool, &session.peer, notifications);
2573 }
2574
2575 response.send(proto::Ack {})?;
2576 Ok(())
2577}
2578
2579/// Remove a contact.
2580async fn remove_contact(
2581 request: proto::RemoveContact,
2582 response: Response<proto::RemoveContact>,
2583 session: Session,
2584) -> Result<()> {
2585 let requester_id = session.user_id();
2586 let responder_id = UserId::from_proto(request.user_id);
2587 let db = session.db().await;
2588 let (contact_accepted, deleted_notification_id) =
2589 db.remove_contact(requester_id, responder_id).await?;
2590
2591 let pool = session.connection_pool().await;
2592 // Update outgoing contact requests of requester
2593 let mut update = proto::UpdateContacts::default();
2594 if contact_accepted {
2595 update.remove_contacts.push(responder_id.to_proto());
2596 } else {
2597 update
2598 .remove_outgoing_requests
2599 .push(responder_id.to_proto());
2600 }
2601 for connection_id in pool.user_connection_ids(requester_id) {
2602 session.peer.send(connection_id, update.clone())?;
2603 }
2604
2605 // Update incoming contact requests of responder
2606 let mut update = proto::UpdateContacts::default();
2607 if contact_accepted {
2608 update.remove_contacts.push(requester_id.to_proto());
2609 } else {
2610 update
2611 .remove_incoming_requests
2612 .push(requester_id.to_proto());
2613 }
2614 for connection_id in pool.user_connection_ids(responder_id) {
2615 session.peer.send(connection_id, update.clone())?;
2616 if let Some(notification_id) = deleted_notification_id {
2617 session.peer.send(
2618 connection_id,
2619 proto::DeleteNotification {
2620 notification_id: notification_id.to_proto(),
2621 },
2622 )?;
2623 }
2624 }
2625
2626 response.send(proto::Ack {})?;
2627 Ok(())
2628}
2629
2630fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2631 version.0.minor() < 139
2632}
2633
2634async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2635 let plan = session.current_plan(&session.db().await).await?;
2636
2637 session
2638 .peer
2639 .send(
2640 session.connection_id,
2641 proto::UpdateUserPlan { plan: plan.into() },
2642 )
2643 .trace_err();
2644
2645 Ok(())
2646}
2647
2648async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2649 subscribe_user_to_channels(session.user_id(), &session).await?;
2650 Ok(())
2651}
2652
2653async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2654 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2655 let mut pool = session.connection_pool().await;
2656 for membership in &channels_for_user.channel_memberships {
2657 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2658 }
2659 session.peer.send(
2660 session.connection_id,
2661 build_update_user_channels(&channels_for_user),
2662 )?;
2663 session.peer.send(
2664 session.connection_id,
2665 build_channels_update(channels_for_user),
2666 )?;
2667 Ok(())
2668}
2669
2670/// Creates a new channel.
2671async fn create_channel(
2672 request: proto::CreateChannel,
2673 response: Response<proto::CreateChannel>,
2674 session: Session,
2675) -> Result<()> {
2676 let db = session.db().await;
2677
2678 let parent_id = request.parent_id.map(ChannelId::from_proto);
2679 let (channel, membership) = db
2680 .create_channel(&request.name, parent_id, session.user_id())
2681 .await?;
2682
2683 let root_id = channel.root_id();
2684 let channel = Channel::from_model(channel);
2685
2686 response.send(proto::CreateChannelResponse {
2687 channel: Some(channel.to_proto()),
2688 parent_id: request.parent_id,
2689 })?;
2690
2691 let mut connection_pool = session.connection_pool().await;
2692 if let Some(membership) = membership {
2693 connection_pool.subscribe_to_channel(
2694 membership.user_id,
2695 membership.channel_id,
2696 membership.role,
2697 );
2698 let update = proto::UpdateUserChannels {
2699 channel_memberships: vec![proto::ChannelMembership {
2700 channel_id: membership.channel_id.to_proto(),
2701 role: membership.role.into(),
2702 }],
2703 ..Default::default()
2704 };
2705 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2706 session.peer.send(connection_id, update.clone())?;
2707 }
2708 }
2709
2710 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2711 if !role.can_see_channel(channel.visibility) {
2712 continue;
2713 }
2714
2715 let update = proto::UpdateChannels {
2716 channels: vec![channel.to_proto()],
2717 ..Default::default()
2718 };
2719 session.peer.send(connection_id, update.clone())?;
2720 }
2721
2722 Ok(())
2723}
2724
2725/// Delete a channel
2726async fn delete_channel(
2727 request: proto::DeleteChannel,
2728 response: Response<proto::DeleteChannel>,
2729 session: Session,
2730) -> Result<()> {
2731 let db = session.db().await;
2732
2733 let channel_id = request.channel_id;
2734 let (root_channel, removed_channels) = db
2735 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2736 .await?;
2737 response.send(proto::Ack {})?;
2738
2739 // Notify members of removed channels
2740 let mut update = proto::UpdateChannels::default();
2741 update
2742 .delete_channels
2743 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2744
2745 let connection_pool = session.connection_pool().await;
2746 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2747 session.peer.send(connection_id, update.clone())?;
2748 }
2749
2750 Ok(())
2751}
2752
2753/// Invite someone to join a channel.
2754async fn invite_channel_member(
2755 request: proto::InviteChannelMember,
2756 response: Response<proto::InviteChannelMember>,
2757 session: Session,
2758) -> Result<()> {
2759 let db = session.db().await;
2760 let channel_id = ChannelId::from_proto(request.channel_id);
2761 let invitee_id = UserId::from_proto(request.user_id);
2762 let InviteMemberResult {
2763 channel,
2764 notifications,
2765 } = db
2766 .invite_channel_member(
2767 channel_id,
2768 invitee_id,
2769 session.user_id(),
2770 request.role().into(),
2771 )
2772 .await?;
2773
2774 let update = proto::UpdateChannels {
2775 channel_invitations: vec![channel.to_proto()],
2776 ..Default::default()
2777 };
2778
2779 let connection_pool = session.connection_pool().await;
2780 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2781 session.peer.send(connection_id, update.clone())?;
2782 }
2783
2784 send_notifications(&connection_pool, &session.peer, notifications);
2785
2786 response.send(proto::Ack {})?;
2787 Ok(())
2788}
2789
2790/// remove someone from a channel
2791async fn remove_channel_member(
2792 request: proto::RemoveChannelMember,
2793 response: Response<proto::RemoveChannelMember>,
2794 session: Session,
2795) -> Result<()> {
2796 let db = session.db().await;
2797 let channel_id = ChannelId::from_proto(request.channel_id);
2798 let member_id = UserId::from_proto(request.user_id);
2799
2800 let RemoveChannelMemberResult {
2801 membership_update,
2802 notification_id,
2803 } = db
2804 .remove_channel_member(channel_id, member_id, session.user_id())
2805 .await?;
2806
2807 let mut connection_pool = session.connection_pool().await;
2808 notify_membership_updated(
2809 &mut connection_pool,
2810 membership_update,
2811 member_id,
2812 &session.peer,
2813 );
2814 for connection_id in connection_pool.user_connection_ids(member_id) {
2815 if let Some(notification_id) = notification_id {
2816 session
2817 .peer
2818 .send(
2819 connection_id,
2820 proto::DeleteNotification {
2821 notification_id: notification_id.to_proto(),
2822 },
2823 )
2824 .trace_err();
2825 }
2826 }
2827
2828 response.send(proto::Ack {})?;
2829 Ok(())
2830}
2831
2832/// Toggle the channel between public and private.
2833/// Care is taken to maintain the invariant that public channels only descend from public channels,
2834/// (though members-only channels can appear at any point in the hierarchy).
2835async fn set_channel_visibility(
2836 request: proto::SetChannelVisibility,
2837 response: Response<proto::SetChannelVisibility>,
2838 session: Session,
2839) -> Result<()> {
2840 let db = session.db().await;
2841 let channel_id = ChannelId::from_proto(request.channel_id);
2842 let visibility = request.visibility().into();
2843
2844 let channel_model = db
2845 .set_channel_visibility(channel_id, visibility, session.user_id())
2846 .await?;
2847 let root_id = channel_model.root_id();
2848 let channel = Channel::from_model(channel_model);
2849
2850 let mut connection_pool = session.connection_pool().await;
2851 for (user_id, role) in connection_pool
2852 .channel_user_ids(root_id)
2853 .collect::<Vec<_>>()
2854 .into_iter()
2855 {
2856 let update = if role.can_see_channel(channel.visibility) {
2857 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2858 proto::UpdateChannels {
2859 channels: vec![channel.to_proto()],
2860 ..Default::default()
2861 }
2862 } else {
2863 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2864 proto::UpdateChannels {
2865 delete_channels: vec![channel.id.to_proto()],
2866 ..Default::default()
2867 }
2868 };
2869
2870 for connection_id in connection_pool.user_connection_ids(user_id) {
2871 session.peer.send(connection_id, update.clone())?;
2872 }
2873 }
2874
2875 response.send(proto::Ack {})?;
2876 Ok(())
2877}
2878
2879/// Alter the role for a user in the channel.
2880async fn set_channel_member_role(
2881 request: proto::SetChannelMemberRole,
2882 response: Response<proto::SetChannelMemberRole>,
2883 session: Session,
2884) -> Result<()> {
2885 let db = session.db().await;
2886 let channel_id = ChannelId::from_proto(request.channel_id);
2887 let member_id = UserId::from_proto(request.user_id);
2888 let result = db
2889 .set_channel_member_role(
2890 channel_id,
2891 session.user_id(),
2892 member_id,
2893 request.role().into(),
2894 )
2895 .await?;
2896
2897 match result {
2898 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2899 let mut connection_pool = session.connection_pool().await;
2900 notify_membership_updated(
2901 &mut connection_pool,
2902 membership_update,
2903 member_id,
2904 &session.peer,
2905 )
2906 }
2907 db::SetMemberRoleResult::InviteUpdated(channel) => {
2908 let update = proto::UpdateChannels {
2909 channel_invitations: vec![channel.to_proto()],
2910 ..Default::default()
2911 };
2912
2913 for connection_id in session
2914 .connection_pool()
2915 .await
2916 .user_connection_ids(member_id)
2917 {
2918 session.peer.send(connection_id, update.clone())?;
2919 }
2920 }
2921 }
2922
2923 response.send(proto::Ack {})?;
2924 Ok(())
2925}
2926
2927/// Change the name of a channel
2928async fn rename_channel(
2929 request: proto::RenameChannel,
2930 response: Response<proto::RenameChannel>,
2931 session: Session,
2932) -> Result<()> {
2933 let db = session.db().await;
2934 let channel_id = ChannelId::from_proto(request.channel_id);
2935 let channel_model = db
2936 .rename_channel(channel_id, session.user_id(), &request.name)
2937 .await?;
2938 let root_id = channel_model.root_id();
2939 let channel = Channel::from_model(channel_model);
2940
2941 response.send(proto::RenameChannelResponse {
2942 channel: Some(channel.to_proto()),
2943 })?;
2944
2945 let connection_pool = session.connection_pool().await;
2946 let update = proto::UpdateChannels {
2947 channels: vec![channel.to_proto()],
2948 ..Default::default()
2949 };
2950 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2951 if role.can_see_channel(channel.visibility) {
2952 session.peer.send(connection_id, update.clone())?;
2953 }
2954 }
2955
2956 Ok(())
2957}
2958
2959/// Move a channel to a new parent.
2960async fn move_channel(
2961 request: proto::MoveChannel,
2962 response: Response<proto::MoveChannel>,
2963 session: Session,
2964) -> Result<()> {
2965 let channel_id = ChannelId::from_proto(request.channel_id);
2966 let to = ChannelId::from_proto(request.to);
2967
2968 let (root_id, channels) = session
2969 .db()
2970 .await
2971 .move_channel(channel_id, to, session.user_id())
2972 .await?;
2973
2974 let connection_pool = session.connection_pool().await;
2975 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2976 let channels = channels
2977 .iter()
2978 .filter_map(|channel| {
2979 if role.can_see_channel(channel.visibility) {
2980 Some(channel.to_proto())
2981 } else {
2982 None
2983 }
2984 })
2985 .collect::<Vec<_>>();
2986 if channels.is_empty() {
2987 continue;
2988 }
2989
2990 let update = proto::UpdateChannels {
2991 channels,
2992 ..Default::default()
2993 };
2994
2995 session.peer.send(connection_id, update.clone())?;
2996 }
2997
2998 response.send(Ack {})?;
2999 Ok(())
3000}
3001
3002/// Get the list of channel members
3003async fn get_channel_members(
3004 request: proto::GetChannelMembers,
3005 response: Response<proto::GetChannelMembers>,
3006 session: Session,
3007) -> Result<()> {
3008 let db = session.db().await;
3009 let channel_id = ChannelId::from_proto(request.channel_id);
3010 let limit = if request.limit == 0 {
3011 u16::MAX as u64
3012 } else {
3013 request.limit
3014 };
3015 let (members, users) = db
3016 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3017 .await?;
3018 response.send(proto::GetChannelMembersResponse { members, users })?;
3019 Ok(())
3020}
3021
3022/// Accept or decline a channel invitation.
3023async fn respond_to_channel_invite(
3024 request: proto::RespondToChannelInvite,
3025 response: Response<proto::RespondToChannelInvite>,
3026 session: Session,
3027) -> Result<()> {
3028 let db = session.db().await;
3029 let channel_id = ChannelId::from_proto(request.channel_id);
3030 let RespondToChannelInvite {
3031 membership_update,
3032 notifications,
3033 } = db
3034 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3035 .await?;
3036
3037 let mut connection_pool = session.connection_pool().await;
3038 if let Some(membership_update) = membership_update {
3039 notify_membership_updated(
3040 &mut connection_pool,
3041 membership_update,
3042 session.user_id(),
3043 &session.peer,
3044 );
3045 } else {
3046 let update = proto::UpdateChannels {
3047 remove_channel_invitations: vec![channel_id.to_proto()],
3048 ..Default::default()
3049 };
3050
3051 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3052 session.peer.send(connection_id, update.clone())?;
3053 }
3054 };
3055
3056 send_notifications(&connection_pool, &session.peer, notifications);
3057
3058 response.send(proto::Ack {})?;
3059
3060 Ok(())
3061}
3062
3063/// Join the channels' room
3064async fn join_channel(
3065 request: proto::JoinChannel,
3066 response: Response<proto::JoinChannel>,
3067 session: Session,
3068) -> Result<()> {
3069 let channel_id = ChannelId::from_proto(request.channel_id);
3070 join_channel_internal(channel_id, Box::new(response), session).await
3071}
3072
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to both `JoinChannel` and `JoinRoom`
/// requests, which share the same response payload type.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request by delegating to
/// `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request by delegating to
/// `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3086
/// Shared implementation for joining a channel's room.
///
/// The steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel via `db.join_channel`.
/// 3. Build LiveKit connection info if a LiveKit client is configured.
/// 4. Reply through `response`.
/// 5. Propagate any membership change and broadcast the updated room.
/// 6. Once the database guard is released, broadcast the channel update and
///    refresh the user's contacts.
///
/// Errors if the joined room has no channel attached.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard around `leave_room_for_session`, then
            // re-acquire it. Presumably this is because that helper takes the
            // database itself (NOTE(review): confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the
        // token fails (the error is logged by `trace_err`). Guests get a
        // guest token with `can_publish = false`; every other role gets a
        // room token with `can_publish = true`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops `db` and `connection_pool`; only
        // `joined_room` escapes.
        joined_room
    };

    // A missing channel on the joined room is reported as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3182
3183/// Start editing the channel notes
3184async fn join_channel_buffer(
3185 request: proto::JoinChannelBuffer,
3186 response: Response<proto::JoinChannelBuffer>,
3187 session: Session,
3188) -> Result<()> {
3189 let db = session.db().await;
3190 let channel_id = ChannelId::from_proto(request.channel_id);
3191
3192 let open_response = db
3193 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3194 .await?;
3195
3196 let collaborators = open_response.collaborators.clone();
3197 response.send(open_response)?;
3198
3199 let update = UpdateChannelBufferCollaborators {
3200 channel_id: channel_id.to_proto(),
3201 collaborators: collaborators.clone(),
3202 };
3203 channel_buffer_updated(
3204 session.connection_id,
3205 collaborators
3206 .iter()
3207 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3208 &update,
3209 &session.peer,
3210 );
3211
3212 Ok(())
3213}
3214
3215/// Edit the channel notes
3216async fn update_channel_buffer(
3217 request: proto::UpdateChannelBuffer,
3218 session: Session,
3219) -> Result<()> {
3220 let db = session.db().await;
3221 let channel_id = ChannelId::from_proto(request.channel_id);
3222
3223 let (collaborators, epoch, version) = db
3224 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3225 .await?;
3226
3227 channel_buffer_updated(
3228 session.connection_id,
3229 collaborators.clone(),
3230 &proto::UpdateChannelBuffer {
3231 channel_id: channel_id.to_proto(),
3232 operations: request.operations,
3233 },
3234 &session.peer,
3235 );
3236
3237 let pool = &*session.connection_pool().await;
3238
3239 let non_collaborators =
3240 pool.channel_connection_ids(channel_id)
3241 .filter_map(|(connection_id, _)| {
3242 if collaborators.contains(&connection_id) {
3243 None
3244 } else {
3245 Some(connection_id)
3246 }
3247 });
3248
3249 broadcast(None, non_collaborators, |peer_id| {
3250 session.peer.send(
3251 peer_id,
3252 proto::UpdateChannels {
3253 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3254 channel_id: channel_id.to_proto(),
3255 epoch: epoch as u64,
3256 version: version.clone(),
3257 }],
3258 ..Default::default()
3259 },
3260 )
3261 });
3262
3263 Ok(())
3264}
3265
3266/// Rejoin the channel notes after a connection blip
3267async fn rejoin_channel_buffers(
3268 request: proto::RejoinChannelBuffers,
3269 response: Response<proto::RejoinChannelBuffers>,
3270 session: Session,
3271) -> Result<()> {
3272 let db = session.db().await;
3273 let buffers = db
3274 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3275 .await?;
3276
3277 for rejoined_buffer in &buffers {
3278 let collaborators_to_notify = rejoined_buffer
3279 .buffer
3280 .collaborators
3281 .iter()
3282 .filter_map(|c| Some(c.peer_id?.into()));
3283 channel_buffer_updated(
3284 session.connection_id,
3285 collaborators_to_notify,
3286 &proto::UpdateChannelBufferCollaborators {
3287 channel_id: rejoined_buffer.buffer.channel_id,
3288 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3289 },
3290 &session.peer,
3291 );
3292 }
3293
3294 response.send(proto::RejoinChannelBuffersResponse {
3295 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3296 })?;
3297
3298 Ok(())
3299}
3300
3301/// Stop editing the channel notes
3302async fn leave_channel_buffer(
3303 request: proto::LeaveChannelBuffer,
3304 response: Response<proto::LeaveChannelBuffer>,
3305 session: Session,
3306) -> Result<()> {
3307 let db = session.db().await;
3308 let channel_id = ChannelId::from_proto(request.channel_id);
3309
3310 let left_buffer = db
3311 .leave_channel_buffer(channel_id, session.connection_id)
3312 .await?;
3313
3314 response.send(Ack {})?;
3315
3316 channel_buffer_updated(
3317 session.connection_id,
3318 left_buffer.connections,
3319 &proto::UpdateChannelBufferCollaborators {
3320 channel_id: channel_id.to_proto(),
3321 collaborators: left_buffer.collaborators,
3322 },
3323 &session.peer,
3324 );
3325
3326 Ok(())
3327}
3328
3329fn channel_buffer_updated<T: EnvelopedMessage>(
3330 sender_id: ConnectionId,
3331 collaborators: impl IntoIterator<Item = ConnectionId>,
3332 message: &T,
3333 peer: &Peer,
3334) {
3335 broadcast(Some(sender_id), collaborators, |peer_id| {
3336 peer.send(peer_id, message.clone())
3337 });
3338}
3339
3340fn send_notifications(
3341 connection_pool: &ConnectionPool,
3342 peer: &Peer,
3343 notifications: db::NotificationBatch,
3344) {
3345 for (user_id, notification) in notifications {
3346 for connection_id in connection_pool.user_connection_ids(user_id) {
3347 if let Err(error) = peer.send(
3348 connection_id,
3349 proto::AddNotification {
3350 notification: Some(notification.clone()),
3351 },
3352 ) {
3353 tracing::error!(
3354 "failed to send notification to {:?} {}",
3355 connection_id,
3356 error
3357 );
3358 }
3359 }
3360 }
3361}
3362
3363/// Send a message to the channel
3364async fn send_channel_message(
3365 request: proto::SendChannelMessage,
3366 response: Response<proto::SendChannelMessage>,
3367 session: Session,
3368) -> Result<()> {
3369 // Validate the message body.
3370 let body = request.body.trim().to_string();
3371 if body.len() > MAX_MESSAGE_LEN {
3372 return Err(anyhow!("message is too long"))?;
3373 }
3374 if body.is_empty() {
3375 return Err(anyhow!("message can't be blank"))?;
3376 }
3377
3378 // TODO: adjust mentions if body is trimmed
3379
3380 let timestamp = OffsetDateTime::now_utc();
3381 let nonce = request
3382 .nonce
3383 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3384
3385 let channel_id = ChannelId::from_proto(request.channel_id);
3386 let CreatedChannelMessage {
3387 message_id,
3388 participant_connection_ids,
3389 notifications,
3390 } = session
3391 .db()
3392 .await
3393 .create_channel_message(
3394 channel_id,
3395 session.user_id(),
3396 &body,
3397 &request.mentions,
3398 timestamp,
3399 nonce.clone().into(),
3400 request.reply_to_message_id.map(MessageId::from_proto),
3401 )
3402 .await?;
3403
3404 let message = proto::ChannelMessage {
3405 sender_id: session.user_id().to_proto(),
3406 id: message_id.to_proto(),
3407 body,
3408 mentions: request.mentions,
3409 timestamp: timestamp.unix_timestamp() as u64,
3410 nonce: Some(nonce),
3411 reply_to_message_id: request.reply_to_message_id,
3412 edited_at: None,
3413 };
3414 broadcast(
3415 Some(session.connection_id),
3416 participant_connection_ids.clone(),
3417 |connection| {
3418 session.peer.send(
3419 connection,
3420 proto::ChannelMessageSent {
3421 channel_id: channel_id.to_proto(),
3422 message: Some(message.clone()),
3423 },
3424 )
3425 },
3426 );
3427 response.send(proto::SendChannelMessageResponse {
3428 message: Some(message),
3429 })?;
3430
3431 let pool = &*session.connection_pool().await;
3432 let non_participants =
3433 pool.channel_connection_ids(channel_id)
3434 .filter_map(|(connection_id, _)| {
3435 if participant_connection_ids.contains(&connection_id) {
3436 None
3437 } else {
3438 Some(connection_id)
3439 }
3440 });
3441 broadcast(None, non_participants, |peer_id| {
3442 session.peer.send(
3443 peer_id,
3444 proto::UpdateChannels {
3445 latest_channel_message_ids: vec![proto::ChannelMessageId {
3446 channel_id: channel_id.to_proto(),
3447 message_id: message_id.to_proto(),
3448 }],
3449 ..Default::default()
3450 },
3451 )
3452 });
3453 send_notifications(pool, &session.peer, notifications);
3454
3455 Ok(())
3456}
3457
3458/// Delete a channel message
3459async fn remove_channel_message(
3460 request: proto::RemoveChannelMessage,
3461 response: Response<proto::RemoveChannelMessage>,
3462 session: Session,
3463) -> Result<()> {
3464 let channel_id = ChannelId::from_proto(request.channel_id);
3465 let message_id = MessageId::from_proto(request.message_id);
3466 let (connection_ids, existing_notification_ids) = session
3467 .db()
3468 .await
3469 .remove_channel_message(channel_id, message_id, session.user_id())
3470 .await?;
3471
3472 broadcast(
3473 Some(session.connection_id),
3474 connection_ids,
3475 move |connection| {
3476 session.peer.send(connection, request.clone())?;
3477
3478 for notification_id in &existing_notification_ids {
3479 session.peer.send(
3480 connection,
3481 proto::DeleteNotification {
3482 notification_id: (*notification_id).to_proto(),
3483 },
3484 )?;
3485 }
3486
3487 Ok(())
3488 },
3489 );
3490 response.send(proto::Ack {})?;
3491 Ok(())
3492}
3493
3494async fn update_channel_message(
3495 request: proto::UpdateChannelMessage,
3496 response: Response<proto::UpdateChannelMessage>,
3497 session: Session,
3498) -> Result<()> {
3499 let channel_id = ChannelId::from_proto(request.channel_id);
3500 let message_id = MessageId::from_proto(request.message_id);
3501 let updated_at = OffsetDateTime::now_utc();
3502 let UpdatedChannelMessage {
3503 message_id,
3504 participant_connection_ids,
3505 notifications,
3506 reply_to_message_id,
3507 timestamp,
3508 deleted_mention_notification_ids,
3509 updated_mention_notifications,
3510 } = session
3511 .db()
3512 .await
3513 .update_channel_message(
3514 channel_id,
3515 message_id,
3516 session.user_id(),
3517 request.body.as_str(),
3518 &request.mentions,
3519 updated_at,
3520 )
3521 .await?;
3522
3523 let nonce = request
3524 .nonce
3525 .clone()
3526 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3527
3528 let message = proto::ChannelMessage {
3529 sender_id: session.user_id().to_proto(),
3530 id: message_id.to_proto(),
3531 body: request.body.clone(),
3532 mentions: request.mentions.clone(),
3533 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3534 nonce: Some(nonce),
3535 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3536 edited_at: Some(updated_at.unix_timestamp() as u64),
3537 };
3538
3539 response.send(proto::Ack {})?;
3540
3541 let pool = &*session.connection_pool().await;
3542 broadcast(
3543 Some(session.connection_id),
3544 participant_connection_ids,
3545 |connection| {
3546 session.peer.send(
3547 connection,
3548 proto::ChannelMessageUpdate {
3549 channel_id: channel_id.to_proto(),
3550 message: Some(message.clone()),
3551 },
3552 )?;
3553
3554 for notification_id in &deleted_mention_notification_ids {
3555 session.peer.send(
3556 connection,
3557 proto::DeleteNotification {
3558 notification_id: (*notification_id).to_proto(),
3559 },
3560 )?;
3561 }
3562
3563 for notification in &updated_mention_notifications {
3564 session.peer.send(
3565 connection,
3566 proto::UpdateNotification {
3567 notification: Some(notification.clone()),
3568 },
3569 )?;
3570 }
3571
3572 Ok(())
3573 },
3574 );
3575
3576 send_notifications(pool, &session.peer, notifications);
3577
3578 Ok(())
3579}
3580
3581/// Mark a channel message as read
3582async fn acknowledge_channel_message(
3583 request: proto::AckChannelMessage,
3584 session: Session,
3585) -> Result<()> {
3586 let channel_id = ChannelId::from_proto(request.channel_id);
3587 let message_id = MessageId::from_proto(request.message_id);
3588 let notifications = session
3589 .db()
3590 .await
3591 .observe_channel_message(channel_id, session.user_id(), message_id)
3592 .await?;
3593 send_notifications(
3594 &*session.connection_pool().await,
3595 &session.peer,
3596 notifications,
3597 );
3598 Ok(())
3599}
3600
3601/// Mark a buffer version as synced
3602async fn acknowledge_buffer_version(
3603 request: proto::AckBufferOperation,
3604 session: Session,
3605) -> Result<()> {
3606 let buffer_id = BufferId::from_proto(request.buffer_id);
3607 session
3608 .db()
3609 .await
3610 .observe_buffer_version(
3611 buffer_id,
3612 session.user_id(),
3613 request.epoch as i32,
3614 &request.version,
3615 )
3616 .await?;
3617 Ok(())
3618}
3619
3620async fn count_language_model_tokens(
3621 request: proto::CountLanguageModelTokens,
3622 response: Response<proto::CountLanguageModelTokens>,
3623 session: Session,
3624 config: &Config,
3625) -> Result<()> {
3626 authorize_access_to_legacy_llm_endpoints(&session).await?;
3627
3628 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3629 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3630 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3631 };
3632
3633 session
3634 .app_state
3635 .rate_limiter
3636 .check(&*rate_limit, session.user_id())
3637 .await?;
3638
3639 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3640 Some(proto::LanguageModelProvider::Google) => {
3641 let api_key = config
3642 .google_ai_api_key
3643 .as_ref()
3644 .context("no Google AI API key configured on the server")?;
3645 google_ai::count_tokens(
3646 session.http_client.as_ref(),
3647 google_ai::API_URL,
3648 api_key,
3649 serde_json::from_str(&request.request)?,
3650 )
3651 .await?
3652 }
3653 _ => return Err(anyhow!("unsupported provider"))?,
3654 };
3655
3656 response.send(proto::CountLanguageModelTokensResponse {
3657 token_count: result.total_tokens as u32,
3658 })?;
3659
3660 Ok(())
3661}
3662
3663struct ZedProCountLanguageModelTokensRateLimit;
3664
3665impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3666 fn capacity(&self) -> usize {
3667 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3668 .ok()
3669 .and_then(|v| v.parse().ok())
3670 .unwrap_or(600) // Picked arbitrarily
3671 }
3672
3673 fn refill_duration(&self) -> chrono::Duration {
3674 chrono::Duration::hours(1)
3675 }
3676
3677 fn db_name(&self) -> &'static str {
3678 "zed-pro:count-language-model-tokens"
3679 }
3680}
3681
3682struct FreeCountLanguageModelTokensRateLimit;
3683
3684impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3685 fn capacity(&self) -> usize {
3686 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3687 .ok()
3688 .and_then(|v| v.parse().ok())
3689 .unwrap_or(600 / 10) // Picked arbitrarily
3690 }
3691
3692 fn refill_duration(&self) -> chrono::Duration {
3693 chrono::Duration::hours(1)
3694 }
3695
3696 fn db_name(&self) -> &'static str {
3697 "free:count-language-model-tokens"
3698 }
3699}
3700
3701struct ZedProComputeEmbeddingsRateLimit;
3702
3703impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3704 fn capacity(&self) -> usize {
3705 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3706 .ok()
3707 .and_then(|v| v.parse().ok())
3708 .unwrap_or(5000) // Picked arbitrarily
3709 }
3710
3711 fn refill_duration(&self) -> chrono::Duration {
3712 chrono::Duration::hours(1)
3713 }
3714
3715 fn db_name(&self) -> &'static str {
3716 "zed-pro:compute-embeddings"
3717 }
3718}
3719
3720struct FreeComputeEmbeddingsRateLimit;
3721
3722impl RateLimit for FreeComputeEmbeddingsRateLimit {
3723 fn capacity(&self) -> usize {
3724 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3725 .ok()
3726 .and_then(|v| v.parse().ok())
3727 .unwrap_or(5000 / 10) // Picked arbitrarily
3728 }
3729
3730 fn refill_duration(&self) -> chrono::Duration {
3731 chrono::Duration::hours(1)
3732 }
3733
3734 fn db_name(&self) -> &'static str {
3735 "free:compute-embeddings"
3736 }
3737}
3738
3739async fn compute_embeddings(
3740 request: proto::ComputeEmbeddings,
3741 response: Response<proto::ComputeEmbeddings>,
3742 session: Session,
3743 api_key: Option<Arc<str>>,
3744) -> Result<()> {
3745 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3746 authorize_access_to_legacy_llm_endpoints(&session).await?;
3747
3748 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3749 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3750 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3751 };
3752
3753 session
3754 .app_state
3755 .rate_limiter
3756 .check(&*rate_limit, session.user_id())
3757 .await?;
3758
3759 let embeddings = match request.model.as_str() {
3760 "openai/text-embedding-3-small" => {
3761 open_ai::embed(
3762 session.http_client.as_ref(),
3763 OPEN_AI_API_URL,
3764 &api_key,
3765 OpenAiEmbeddingModel::TextEmbedding3Small,
3766 request.texts.iter().map(|text| text.as_str()),
3767 )
3768 .await?
3769 }
3770 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3771 };
3772
3773 let embeddings = request
3774 .texts
3775 .iter()
3776 .map(|text| {
3777 let mut hasher = sha2::Sha256::new();
3778 hasher.update(text.as_bytes());
3779 let result = hasher.finalize();
3780 result.to_vec()
3781 })
3782 .zip(
3783 embeddings
3784 .data
3785 .into_iter()
3786 .map(|embedding| embedding.embedding),
3787 )
3788 .collect::<HashMap<_, _>>();
3789
3790 let db = session.db().await;
3791 db.save_embeddings(&request.model, &embeddings)
3792 .await
3793 .context("failed to save embeddings")
3794 .trace_err();
3795
3796 response.send(proto::ComputeEmbeddingsResponse {
3797 embeddings: embeddings
3798 .into_iter()
3799 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3800 .collect(),
3801 })?;
3802 Ok(())
3803}
3804
3805async fn get_cached_embeddings(
3806 request: proto::GetCachedEmbeddings,
3807 response: Response<proto::GetCachedEmbeddings>,
3808 session: Session,
3809) -> Result<()> {
3810 authorize_access_to_legacy_llm_endpoints(&session).await?;
3811
3812 let db = session.db().await;
3813 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3814
3815 response.send(proto::GetCachedEmbeddingsResponse {
3816 embeddings: embeddings
3817 .into_iter()
3818 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3819 .collect(),
3820 })?;
3821 Ok(())
3822}
3823
3824/// This is leftover from before the LLM service.
3825///
3826/// The endpoints protected by this check will be moved there eventually.
3827async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3828 if session.is_staff() {
3829 Ok(())
3830 } else {
3831 Err(anyhow!("permission denied"))?
3832 }
3833}
3834
3835/// Get a Supermaven API key for the user
3836async fn get_supermaven_api_key(
3837 _request: proto::GetSupermavenApiKey,
3838 response: Response<proto::GetSupermavenApiKey>,
3839 session: Session,
3840) -> Result<()> {
3841 let user_id: String = session.user_id().to_string();
3842 if !session.is_staff() {
3843 return Err(anyhow!("supermaven not enabled for this account"))?;
3844 }
3845
3846 let email = session
3847 .email()
3848 .ok_or_else(|| anyhow!("user must have an email"))?;
3849
3850 let supermaven_admin_api = session
3851 .supermaven_client
3852 .as_ref()
3853 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3854
3855 let result = supermaven_admin_api
3856 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3857 .await?;
3858
3859 response.send(proto::GetSupermavenApiKeyResponse {
3860 api_key: result.api_key,
3861 })?;
3862
3863 Ok(())
3864}
3865
3866/// Start receiving chat updates for a channel
3867async fn join_channel_chat(
3868 request: proto::JoinChannelChat,
3869 response: Response<proto::JoinChannelChat>,
3870 session: Session,
3871) -> Result<()> {
3872 let channel_id = ChannelId::from_proto(request.channel_id);
3873
3874 let db = session.db().await;
3875 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3876 .await?;
3877 let messages = db
3878 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3879 .await?;
3880 response.send(proto::JoinChannelChatResponse {
3881 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3882 messages,
3883 })?;
3884 Ok(())
3885}
3886
3887/// Stop receiving chat updates for a channel
3888async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3889 let channel_id = ChannelId::from_proto(request.channel_id);
3890 session
3891 .db()
3892 .await
3893 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3894 .await?;
3895 Ok(())
3896}
3897
3898/// Retrieve the chat history for a channel
3899async fn get_channel_messages(
3900 request: proto::GetChannelMessages,
3901 response: Response<proto::GetChannelMessages>,
3902 session: Session,
3903) -> Result<()> {
3904 let channel_id = ChannelId::from_proto(request.channel_id);
3905 let messages = session
3906 .db()
3907 .await
3908 .get_channel_messages(
3909 channel_id,
3910 session.user_id(),
3911 MESSAGE_COUNT_PER_PAGE,
3912 Some(MessageId::from_proto(request.before_message_id)),
3913 )
3914 .await?;
3915 response.send(proto::GetChannelMessagesResponse {
3916 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3917 messages,
3918 })?;
3919 Ok(())
3920}
3921
3922/// Retrieve specific chat messages
3923async fn get_channel_messages_by_id(
3924 request: proto::GetChannelMessagesById,
3925 response: Response<proto::GetChannelMessagesById>,
3926 session: Session,
3927) -> Result<()> {
3928 let message_ids = request
3929 .message_ids
3930 .iter()
3931 .map(|id| MessageId::from_proto(*id))
3932 .collect::<Vec<_>>();
3933 let messages = session
3934 .db()
3935 .await
3936 .get_channel_messages_by_id(session.user_id(), &message_ids)
3937 .await?;
3938 response.send(proto::GetChannelMessagesResponse {
3939 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3940 messages,
3941 })?;
3942 Ok(())
3943}
3944
3945/// Retrieve the current users notifications
3946async fn get_notifications(
3947 request: proto::GetNotifications,
3948 response: Response<proto::GetNotifications>,
3949 session: Session,
3950) -> Result<()> {
3951 let notifications = session
3952 .db()
3953 .await
3954 .get_notifications(
3955 session.user_id(),
3956 NOTIFICATION_COUNT_PER_PAGE,
3957 request.before_id.map(db::NotificationId::from_proto),
3958 )
3959 .await?;
3960 response.send(proto::GetNotificationsResponse {
3961 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3962 notifications,
3963 })?;
3964 Ok(())
3965}
3966
3967/// Mark notifications as read
3968async fn mark_notification_as_read(
3969 request: proto::MarkNotificationRead,
3970 response: Response<proto::MarkNotificationRead>,
3971 session: Session,
3972) -> Result<()> {
3973 let database = &session.db().await;
3974 let notifications = database
3975 .mark_notification_as_read_by_id(
3976 session.user_id(),
3977 NotificationId::from_proto(request.notification_id),
3978 )
3979 .await?;
3980 send_notifications(
3981 &*session.connection_pool().await,
3982 &session.peer,
3983 notifications,
3984 );
3985 response.send(proto::Ack {})?;
3986 Ok(())
3987}
3988
3989/// Get the current users information
3990async fn get_private_user_info(
3991 _request: proto::GetPrivateUserInfo,
3992 response: Response<proto::GetPrivateUserInfo>,
3993 session: Session,
3994) -> Result<()> {
3995 let db = session.db().await;
3996
3997 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3998 let user = db
3999 .get_user_by_id(session.user_id())
4000 .await?
4001 .ok_or_else(|| anyhow!("user not found"))?;
4002 let flags = db.get_user_flags(session.user_id()).await?;
4003
4004 response.send(proto::GetPrivateUserInfoResponse {
4005 metrics_id,
4006 staff: user.admin,
4007 flags,
4008 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4009 })?;
4010 Ok(())
4011}
4012
4013/// Accept the terms of service (tos) on behalf of the current user
4014async fn accept_terms_of_service(
4015 _request: proto::AcceptTermsOfService,
4016 response: Response<proto::AcceptTermsOfService>,
4017 session: Session,
4018) -> Result<()> {
4019 let db = session.db().await;
4020
4021 let accepted_tos_at = Utc::now();
4022 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4023 .await?;
4024
4025 response.send(proto::AcceptTermsOfServiceResponse {
4026 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4027 })?;
4028 Ok(())
4029}
4030
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the older of the Zed account's
/// `created_at` and the linked GitHub account's creation date. It skips this
/// check entirely for users who have an LLM subscription or the
/// `bypass-account-age-check` feature flag.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4033
4034async fn get_llm_api_token(
4035 _request: proto::GetLlmToken,
4036 response: Response<proto::GetLlmToken>,
4037 session: Session,
4038) -> Result<()> {
4039 let db = session.db().await;
4040
4041 let flags = db.get_user_flags(session.user_id()).await?;
4042 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4043 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4044 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4045
4046 if !session.is_staff() && !has_language_models_feature_flag {
4047 Err(anyhow!("permission denied"))?
4048 }
4049
4050 let user_id = session.user_id();
4051 let user = db
4052 .get_user_by_id(user_id)
4053 .await?
4054 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4055
4056 if user.accepted_tos_at.is_none() {
4057 Err(anyhow!("terms of service not accepted"))?
4058 }
4059
4060 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4061
4062 let bypass_account_age_check =
4063 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4064 if !bypass_account_age_check {
4065 let mut account_created_at = user.created_at;
4066 if let Some(github_created_at) = user.github_user_created_at {
4067 account_created_at = account_created_at.min(github_created_at);
4068 }
4069 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4070 Err(anyhow!("account too young"))?
4071 }
4072 }
4073
4074 let billing_preferences = db.get_billing_preferences(user.id).await?;
4075
4076 let token = LlmTokenClaims::create(
4077 &user,
4078 session.is_staff(),
4079 billing_preferences,
4080 has_llm_closed_beta_feature_flag,
4081 has_predict_edits_feature_flag,
4082 has_llm_subscription,
4083 session.current_plan(&db).await?,
4084 session.system_id.clone(),
4085 &session.app_state.config,
4086 )?;
4087 response.send(proto::GetLlmTokenResponse { token })?;
4088 Ok(())
4089}
4090
4091fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4092 let message = match message {
4093 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4094 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4095 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4096 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4097 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4098 code: frame.code.into(),
4099 reason: frame.reason,
4100 })),
4101 // We should never receive a frame while reading the message, according
4102 // to the `tungstenite` maintainers:
4103 //
4104 // > It cannot occur when you read messages from the WebSocket, but it
4105 // > can be used when you want to send the raw frames (e.g. you want to
4106 // > send the frames to the WebSocket without composing the full message first).
4107 // >
4108 // > — https://github.com/snapview/tungstenite-rs/issues/268
4109 TungsteniteMessage::Frame(_) => {
4110 bail!("received an unexpected frame while reading the message")
4111 }
4112 };
4113
4114 Ok(message)
4115}
4116
4117fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4118 match message {
4119 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4120 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4121 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4122 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4123 AxumMessage::Close(frame) => {
4124 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4125 code: frame.code.into(),
4126 reason: frame.reason,
4127 }))
4128 }
4129 }
4130}
4131
4132fn notify_membership_updated(
4133 connection_pool: &mut ConnectionPool,
4134 result: MembershipUpdated,
4135 user_id: UserId,
4136 peer: &Peer,
4137) {
4138 for membership in &result.new_channels.channel_memberships {
4139 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4140 }
4141 for channel_id in &result.removed_channels {
4142 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4143 }
4144
4145 let user_channels_update = proto::UpdateUserChannels {
4146 channel_memberships: result
4147 .new_channels
4148 .channel_memberships
4149 .iter()
4150 .map(|cm| proto::ChannelMembership {
4151 channel_id: cm.channel_id.to_proto(),
4152 role: cm.role.into(),
4153 })
4154 .collect(),
4155 ..Default::default()
4156 };
4157
4158 let mut update = build_channels_update(result.new_channels);
4159 update.delete_channels = result
4160 .removed_channels
4161 .into_iter()
4162 .map(|id| id.to_proto())
4163 .collect();
4164 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4165
4166 for connection_id in connection_pool.user_connection_ids(user_id) {
4167 peer.send(connection_id, user_channels_update.clone())
4168 .trace_err();
4169 peer.send(connection_id, update.clone()).trace_err();
4170 }
4171}
4172
4173fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4174 proto::UpdateUserChannels {
4175 channel_memberships: channels
4176 .channel_memberships
4177 .iter()
4178 .map(|m| proto::ChannelMembership {
4179 channel_id: m.channel_id.to_proto(),
4180 role: m.role.into(),
4181 })
4182 .collect(),
4183 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4184 observed_channel_message_id: channels.observed_channel_messages.clone(),
4185 }
4186}
4187
4188fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4189 let mut update = proto::UpdateChannels::default();
4190
4191 for channel in channels.channels {
4192 update.channels.push(channel.to_proto());
4193 }
4194
4195 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4196 update.latest_channel_message_ids = channels.latest_channel_messages;
4197
4198 for (channel_id, participants) in channels.channel_participants {
4199 update
4200 .channel_participants
4201 .push(proto::ChannelParticipants {
4202 channel_id: channel_id.to_proto(),
4203 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4204 });
4205 }
4206
4207 for channel in channels.invited_channels {
4208 update.channel_invitations.push(channel.to_proto());
4209 }
4210
4211 update
4212}
4213
4214fn build_initial_contacts_update(
4215 contacts: Vec<db::Contact>,
4216 pool: &ConnectionPool,
4217) -> proto::UpdateContacts {
4218 let mut update = proto::UpdateContacts::default();
4219
4220 for contact in contacts {
4221 match contact {
4222 db::Contact::Accepted { user_id, busy } => {
4223 update.contacts.push(contact_for_user(user_id, busy, pool));
4224 }
4225 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4226 db::Contact::Incoming { user_id } => {
4227 update
4228 .incoming_requests
4229 .push(proto::IncomingContactRequest {
4230 requester_id: user_id.to_proto(),
4231 })
4232 }
4233 }
4234 }
4235
4236 update
4237}
4238
4239fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4240 proto::Contact {
4241 user_id: user_id.to_proto(),
4242 online: pool.is_user_online(user_id),
4243 busy,
4244 }
4245}
4246
4247fn room_updated(room: &proto::Room, peer: &Peer) {
4248 broadcast(
4249 None,
4250 room.participants
4251 .iter()
4252 .filter_map(|participant| Some(participant.peer_id?.into())),
4253 |peer_id| {
4254 peer.send(
4255 peer_id,
4256 proto::RoomUpdated {
4257 room: Some(room.clone()),
4258 },
4259 )
4260 },
4261 );
4262}
4263
4264fn channel_updated(
4265 channel: &db::channel::Model,
4266 room: &proto::Room,
4267 peer: &Peer,
4268 pool: &ConnectionPool,
4269) {
4270 let participants = room
4271 .participants
4272 .iter()
4273 .map(|p| p.user_id)
4274 .collect::<Vec<_>>();
4275
4276 broadcast(
4277 None,
4278 pool.channel_connection_ids(channel.root_id())
4279 .filter_map(|(channel_id, role)| {
4280 role.can_see_channel(channel.visibility)
4281 .then_some(channel_id)
4282 }),
4283 |peer_id| {
4284 peer.send(
4285 peer_id,
4286 proto::UpdateChannels {
4287 channel_participants: vec![proto::ChannelParticipants {
4288 channel_id: channel.id.to_proto(),
4289 participant_user_ids: participants.clone(),
4290 }],
4291 ..Default::default()
4292 },
4293 )
4294 },
4295 );
4296}
4297
4298async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4299 let db = session.db().await;
4300
4301 let contacts = db.get_contacts(user_id).await?;
4302 let busy = db.is_user_busy(user_id).await?;
4303
4304 let pool = session.connection_pool().await;
4305 let updated_contact = contact_for_user(user_id, busy, &pool);
4306 for contact in contacts {
4307 if let db::Contact::Accepted {
4308 user_id: contact_user_id,
4309 ..
4310 } = contact
4311 {
4312 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4313 session
4314 .peer
4315 .send(
4316 contact_conn_id,
4317 proto::UpdateContacts {
4318 contacts: vec![updated_contact.clone()],
4319 remove_contacts: Default::default(),
4320 incoming_requests: Default::default(),
4321 remove_incoming_requests: Default::default(),
4322 outgoing_requests: Default::default(),
4323 remove_outgoing_requests: Default::default(),
4324 },
4325 )
4326 .trace_err();
4327 }
4328 }
4329 }
4330 Ok(())
4331}
4332
4333async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4334 let mut contacts_to_update = HashSet::default();
4335
4336 let room_id;
4337 let canceled_calls_to_user_ids;
4338 let livekit_room;
4339 let delete_livekit_room;
4340 let room;
4341 let channel;
4342
4343 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4344 contacts_to_update.insert(session.user_id());
4345
4346 for project in left_room.left_projects.values() {
4347 project_left(project, session);
4348 }
4349
4350 room_id = RoomId::from_proto(left_room.room.id);
4351 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4352 livekit_room = mem::take(&mut left_room.room.livekit_room);
4353 delete_livekit_room = left_room.deleted;
4354 room = mem::take(&mut left_room.room);
4355 channel = mem::take(&mut left_room.channel);
4356
4357 room_updated(&room, &session.peer);
4358 } else {
4359 return Ok(());
4360 }
4361
4362 if let Some(channel) = channel {
4363 channel_updated(
4364 &channel,
4365 &room,
4366 &session.peer,
4367 &*session.connection_pool().await,
4368 );
4369 }
4370
4371 {
4372 let pool = session.connection_pool().await;
4373 for canceled_user_id in canceled_calls_to_user_ids {
4374 for connection_id in pool.user_connection_ids(canceled_user_id) {
4375 session
4376 .peer
4377 .send(
4378 connection_id,
4379 proto::CallCanceled {
4380 room_id: room_id.to_proto(),
4381 },
4382 )
4383 .trace_err();
4384 }
4385 contacts_to_update.insert(canceled_user_id);
4386 }
4387 }
4388
4389 for contact_user_id in contacts_to_update {
4390 update_user_contacts(contact_user_id, session).await?;
4391 }
4392
4393 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4394 live_kit
4395 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4396 .await
4397 .trace_err();
4398
4399 if delete_livekit_room {
4400 live_kit.delete_room(livekit_room).await.trace_err();
4401 }
4402 }
4403
4404 Ok(())
4405}
4406
4407async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4408 let left_channel_buffers = session
4409 .db()
4410 .await
4411 .leave_channel_buffers(session.connection_id)
4412 .await?;
4413
4414 for left_buffer in left_channel_buffers {
4415 channel_buffer_updated(
4416 session.connection_id,
4417 left_buffer.connections,
4418 &proto::UpdateChannelBufferCollaborators {
4419 channel_id: left_buffer.channel_id.to_proto(),
4420 collaborators: left_buffer.collaborators,
4421 },
4422 &session.peer,
4423 );
4424 }
4425
4426 Ok(())
4427}
4428
4429fn project_left(project: &db::LeftProject, session: &Session) {
4430 for connection_id in &project.connection_ids {
4431 if project.should_unshare {
4432 session
4433 .peer
4434 .send(
4435 *connection_id,
4436 proto::UnshareProject {
4437 project_id: project.id.to_proto(),
4438 },
4439 )
4440 .trace_err();
4441 } else {
4442 session
4443 .peer
4444 .send(
4445 *connection_id,
4446 proto::RemoveProjectCollaborator {
4447 project_id: project.id.to_proto(),
4448 peer_id: Some(session.connection_id.into()),
4449 },
4450 )
4451 .trace_err();
4452 }
4453 }
4454}
4455
/// Extension for fallible values whose failure should be reported and then dropped,
/// rather than propagated to the caller.
pub trait ResultExt {
    /// Returns the success value as `Some`, or `None` after reporting the failure.
    fn trace_err(self) -> Option<Self::Ok>;

    /// The type of the success value.
    type Ok;
}
4461
4462impl<T, E> ResultExt for Result<T, E>
4463where
4464 E: std::fmt::Debug,
4465{
4466 type Ok = T;
4467
4468 #[track_caller]
4469 fn trace_err(self) -> Option<T> {
4470 match self {
4471 Ok(value) => Some(value),
4472 Err(error) => {
4473 tracing::error!("{:?}", error);
4474 None
4475 }
4476 }
4477 }
4478}