1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use sha2::Digest;
41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
42
43use futures::{
44 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
45 TryStreamExt,
46};
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, MutexGuard, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
/// A 30 second timeout related to reconnection.
/// NOTE(review): no use is visible in this part of the file; presumably the window a dropped
/// client has to reconnect before its state is released — confirm against the callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up old resources. `Server::start` sleeps for this long before it starts cleaning up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of the file; going by
// their names they bound page sizes and chat message length — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
87
/// Type-erased message handler, as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` of the connection the message
/// arrived on, and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
90
/// Handle passed to a request handler so it can answer exactly one request.
///
/// The `responded` flag is shared with the wrapper built in `Server::add_request_handler`,
/// which reports an error when a handler returns `Ok` without ever calling `send`.
struct Response<R> {
    /// Peer through which the reply is delivered.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`.
    responded: Arc<AtomicBool>,
}
96
97impl<R: RequestMessage> Response<R> {
98 fn send(self, payload: R::Response) -> Result<()> {
99 self.responded.store(true, SeqCst);
100 self.peer.respond(self.receipt, payload)?;
101 Ok(())
102 }
103}
104
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` is acting as `user`. `user` supplies the user id and login; `admin` is recorded
    /// as the impersonator in tracing spans and debug output.
    Impersonated { user: User, admin: User },
}
110
111impl Principal {
112 fn update_span(&self, span: &tracing::Span) {
113 match &self {
114 Principal::User(user) => {
115 span.record("user_id", user.id.0);
116 span.record("login", &user.github_login);
117 }
118 Principal::Impersonated { user, admin } => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 span.record("impersonator", &admin.github_login);
122 }
123 }
124 }
125}
126
/// Per-connection state handed to every message handler.
///
/// One `Session` is built in `Server::handle_connection` for each connection and cloned for
/// every incoming message.
#[derive(Clone)]
struct Session {
    /// The user (possibly impersonated) this connection belongs to.
    principal: Principal,
    /// Id assigned to the connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; use `Session::db` to lock it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server's connection pool; use `Session::connection_pool` to lock it.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Present only when the config provides a Supermaven admin API key.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id passed in by the caller of `handle_connection`, if any.
    system_id: Option<String>,
    /// Executor the connection runs on.
    _executor: Executor,
}
143
144impl Session {
145 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
146 #[cfg(test)]
147 tokio::task::yield_now().await;
148 let guard = self.db.lock().await;
149 #[cfg(test)]
150 tokio::task::yield_now().await;
151 guard
152 }
153
154 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 let guard = self.connection_pool.lock();
158 ConnectionPoolGuard {
159 guard,
160 _not_send: PhantomData,
161 }
162 }
163
164 fn is_staff(&self) -> bool {
165 match &self.principal {
166 Principal::User(user) => user.admin,
167 Principal::Impersonated { .. } => true,
168 }
169 }
170
171 pub async fn has_llm_subscription(
172 &self,
173 db: &MutexGuard<'_, DbHandle>,
174 ) -> anyhow::Result<bool> {
175 if self.is_staff() {
176 return Ok(true);
177 }
178
179 let user_id = self.user_id();
180
181 Ok(db.has_active_billing_subscription(user_id).await?)
182 }
183
184 pub async fn current_plan(
185 &self,
186 _db: &MutexGuard<'_, DbHandle>,
187 ) -> anyhow::Result<proto::Plan> {
188 if self.is_staff() {
189 Ok(proto::Plan::ZedPro)
190 } else {
191 Ok(proto::Plan::Free)
192 }
193 }
194
195 fn user_id(&self) -> UserId {
196 match &self.principal {
197 Principal::User(user) => user.id,
198 Principal::Impersonated { user, .. } => user.id,
199 }
200 }
201
202 pub fn email(&self) -> Option<String> {
203 match &self.principal {
204 Principal::User(user) => user.email_address.clone(),
205 Principal::Impersonated { user, .. } => user.email_address.clone(),
206 }
207 }
208}
209
210impl Debug for Session {
211 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
212 let mut result = f.debug_struct("Session");
213 match &self.principal {
214 Principal::User(user) => {
215 result.field("user", &user.github_login);
216 }
217 Principal::Impersonated { user, admin } => {
218 result.field("user", &user.github_login);
219 result.field("impersonator", &admin.github_login);
220 }
221 }
222 result.field("connection_id", &self.connection_id).finish()
223 }
224}
225
/// Newtype around the shared `Database`; it dereferences to the database.
struct DbHandle(Arc<Database>);
227
228impl Deref for DbHandle {
229 type Target = Database;
230
231 fn deref(&self) -> &Self::Target {
232 self.0.as_ref()
233 }
234}
235
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that owns the connections and sends/receives messages.
    peer: Arc<Peer>,
    /// Connections currently known to this server.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel whose value is `true` while the server is torn down.
    teardown: watch::Sender<bool>,
}
244
/// Lock guard over the `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, regardless of whether the
/// underlying `parking_lot` guard is.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Zero-sized marker that removes `Send` from this type.
    _not_send: PhantomData<Rc<()>>,
}
249
/// Serializable view of the server, holding the connection pool lock for as long as it lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// Locked connection pool; serialized through the guard's `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
256
257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
258where
259 S: Serializer,
260 T: Deref<Target = U>,
261 U: Serialize,
262{
263 Serialize::serialize(value.deref(), serializer)
264}
265
266impl Server {
/// Creates a server with the given id and registers a handler for every supported message
/// type.
///
/// The peer is created from the server id truncated to `u32`. Each `add_*_handler` call goes
/// through `add_handler`, which panics if a message type is registered twice, so every message
/// type may appear at most once in the chain below.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; receivers are created on demand via `subscribe`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Generic project-request forwarders, instantiated once per message type.
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // The next handlers need configuration, so they are closures that capture their own
        // clone of `app_state`.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                // Clone again per call so the returned future owns its `app_state`.
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler(get_cached_embeddings)
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            }
        });

    Arc::new(server)
}
427
/// Schedules cleanup of resources left behind by other (stale) servers.
///
/// A detached task is spawned on the application executor; this method itself returns
/// `Ok(())` immediately. The task:
///
/// 1. waits for `CLEANUP_TIMEOUT`;
/// 2. asks the database for stale room and channel ids for this environment and server id;
/// 3. clears stale channel-buffer collaborators and notifies the remaining connections;
/// 4. clears stale room participants, then notifies about the room update, canceled calls
///    and changed contacts, and deletes the LiveKit room when no participants remain;
/// 5. deletes the stale server records.
///
/// Every failure inside the task is logged through `trace_err` and otherwise ignored.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // The sleep future is created here, before the task is spawned, and awaited inside it.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: drop stale collaborators and tell the remaining
                // connections about the refreshed collaborator list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Collected from the refreshed room; they stay empty/false when the
                    // refresh fails, which turns the steps below into no-ops.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    {
                        // The pool lock is confined to this block, which contains no
                        // `.await`.
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        // Both database calls finish before the pool is locked.
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                // Only accepted contacts are notified.
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Delete the LiveKit room only when the refreshed room has no
                    // participants left and a LiveKit client is configured.
                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when fetching the stale resource ids failed.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
573
574 pub fn teardown(&self) {
575 self.peer.teardown();
576 self.connection_pool.lock().reset();
577 let _ = self.teardown.send(true);
578 }
579
/// Test-only: tears the server down and brings it back under a new id.
///
/// After the teardown the stored id is replaced, the peer is reset with the new id
/// (truncated to `u32`), and `false` is published on the teardown channel.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
    self.teardown.send(false).ok();
}
587
/// Test-only: returns a copy of the current server id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = self.id.lock();
    *current_id
}
592
593 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
594 where
595 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
596 Fut: 'static + Send + Future<Output = Result<()>>,
597 M: EnvelopedMessage,
598 {
599 let prev_handler = self.handlers.insert(
600 TypeId::of::<M>(),
601 Box::new(move |envelope, session| {
602 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
603 let received_at = envelope.received_at;
604 tracing::info!("message received");
605 let start_time = Instant::now();
606 let future = (handler)(*envelope, session);
607 async move {
608 let result = future.await;
609 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
610 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
611 let queue_duration_ms = total_duration_ms - processing_duration_ms;
612 let payload_type = M::NAME;
613
614 match result {
615 Err(error) => {
616 tracing::error!(
617 ?error,
618 total_duration_ms,
619 processing_duration_ms,
620 queue_duration_ms,
621 payload_type,
622 "error handling message"
623 )
624 }
625 Ok(()) => tracing::info!(
626 total_duration_ms,
627 processing_duration_ms,
628 queue_duration_ms,
629 "finished handling message"
630 ),
631 }
632 }
633 .boxed()
634 }),
635 );
636 if prev_handler.is_some() {
637 panic!("registered a handler for the same message twice");
638 }
639 self
640 }
641
642 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
643 where
644 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
645 Fut: 'static + Send + Future<Output = Result<()>>,
646 M: EnvelopedMessage,
647 {
648 self.add_handler(move |envelope, session| handler(envelope.payload, session));
649 self
650 }
651
652 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
653 where
654 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
655 Fut: Send + Future<Output = Result<()>>,
656 M: RequestMessage,
657 {
658 let handler = Arc::new(handler);
659 self.add_handler(move |envelope, session| {
660 let receipt = envelope.receipt();
661 let handler = handler.clone();
662 async move {
663 let peer = session.peer.clone();
664 let responded = Arc::new(AtomicBool::default());
665 let response = Response {
666 peer: peer.clone(),
667 responded: responded.clone(),
668 receipt,
669 };
670 match (handler)(envelope.payload, response, session).await {
671 Ok(()) => {
672 if responded.load(std::sync::atomic::Ordering::SeqCst) {
673 Ok(())
674 } else {
675 Err(anyhow!("handler did not send a response"))?
676 }
677 }
678 Err(error) => {
679 let proto_err = match &error {
680 Error::Internal(err) => err.to_proto(),
681 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
682 };
683 peer.respond_with_error(receipt, proto_err)?;
684 Err(error)
685 }
686 }
687 }
688 })
689 }
690
/// Returns a future that services `connection` until it closes or the server is torn down.
///
/// The future, once polled:
///
/// 1. bails out immediately if the server is already tearing down;
/// 2. registers the connection with the peer and builds the per-connection `Session`;
/// 3. sends the initial client update (hello, contacts, channels, ...);
/// 4. runs the message loop, dispatching each incoming message to its registered handler;
/// 5. when the I/O future finishes or the incoming stream ends, drops unfinished foreground
///    handlers and calls `connection_lost`.
///
/// If `send_connection_id` is provided it receives the new connection id as part of the
/// initial client update.
///
/// NOTE(review): the early returns after `add_connection` (HTTP client creation failure,
/// initial update failure) and the teardown arm of the loop return without calling
/// `connection_lost` — confirm the connection is cleaned up elsewhere in those cases.
// The argument list mirrors everything known about a connection at accept time.
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    // Fields declared `Empty` here are filled in later via `Span::record`.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only available when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            http_client,
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 256 handlers run at once per connection: a permit is acquired before the
        // next message is read and released when its handler finishes.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls the arms in the order written: teardown first, then
            // connection I/O, then in-flight foreground handlers, and only then new messages.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            // Leave the span before the future is instrumented with it.
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                // Release the concurrency permit once handling is done.
                                drop(permit);
                            }.instrument(span);
                            // Background messages run detached on the executor; foreground
                            // ones are driven by this loop.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set cancels any foreground handlers that have not finished.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
840
841 async fn send_initial_client_update(
842 &self,
843 connection_id: ConnectionId,
844 principal: &Principal,
845 zed_version: ZedVersion,
846 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
847 session: &Session,
848 ) -> Result<()> {
849 self.peer.send(
850 connection_id,
851 proto::Hello {
852 peer_id: Some(connection_id.into()),
853 },
854 )?;
855 tracing::info!("sent hello message");
856 if let Some(send_connection_id) = send_connection_id.take() {
857 let _ = send_connection_id.send(connection_id);
858 }
859
860 match principal {
861 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
862 if !user.connected_once {
863 self.peer.send(connection_id, proto::ShowContacts {})?;
864 self.app_state
865 .db
866 .set_user_connected_once(user.id, true)
867 .await?;
868 }
869
870 update_user_plan(user.id, session).await?;
871
872 let contacts = self.app_state.db.get_contacts(user.id).await?;
873
874 {
875 let mut pool = self.connection_pool.lock();
876 pool.add_connection(connection_id, user.id, user.admin, zed_version);
877 self.peer.send(
878 connection_id,
879 build_initial_contacts_update(contacts, &pool),
880 )?;
881 }
882
883 if should_auto_subscribe_to_channels(zed_version) {
884 subscribe_user_to_channels(user.id, session).await?;
885 }
886
887 if let Some(incoming_call) =
888 self.app_state.db.incoming_call_for_user(user.id).await?
889 {
890 self.peer.send(connection_id, incoming_call)?;
891 }
892
893 update_user_contacts(user.id, session).await?;
894 }
895 }
896
897 Ok(())
898 }
899
900 pub async fn invite_code_redeemed(
901 self: &Arc<Self>,
902 inviter_id: UserId,
903 invitee_id: UserId,
904 ) -> Result<()> {
905 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
906 if let Some(code) = &user.invite_code {
907 let pool = self.connection_pool.lock();
908 let invitee_contact = contact_for_user(invitee_id, false, &pool);
909 for connection_id in pool.user_connection_ids(inviter_id) {
910 self.peer.send(
911 connection_id,
912 proto::UpdateContacts {
913 contacts: vec![invitee_contact.clone()],
914 ..Default::default()
915 },
916 )?;
917 self.peer.send(
918 connection_id,
919 proto::UpdateInviteInfo {
920 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
921 count: user.invite_count as u32,
922 },
923 )?;
924 }
925 }
926 }
927 Ok(())
928 }
929
930 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
931 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
932 if let Some(invite_code) = &user.invite_code {
933 let pool = self.connection_pool.lock();
934 for connection_id in pool.user_connection_ids(user_id) {
935 self.peer.send(
936 connection_id,
937 proto::UpdateInviteInfo {
938 url: format!(
939 "{}{}",
940 self.app_state.config.invite_link_prefix, invite_code
941 ),
942 count: user.invite_count as u32,
943 },
944 )?;
945 }
946 }
947 }
948 Ok(())
949 }
950
951 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
952 let pool = self.connection_pool.lock();
953 for connection_id in pool.user_connection_ids(user_id) {
954 self.peer
955 .send(connection_id, proto::RefreshLlmToken {})
956 .trace_err();
957 }
958 }
959
960 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
961 ServerSnapshot {
962 connection_pool: ConnectionPoolGuard {
963 guard: self.connection_pool.lock(),
964 _not_send: PhantomData,
965 },
966 peer: &self.peer,
967 }
968 }
969}
970
971impl<'a> Deref for ConnectionPoolGuard<'a> {
972 type Target = ConnectionPool;
973
974 fn deref(&self) -> &Self::Target {
975 &self.guard
976 }
977}
978
979impl<'a> DerefMut for ConnectionPoolGuard<'a> {
980 fn deref_mut(&mut self) -> &mut Self::Target {
981 &mut self.guard
982 }
983}
984
985impl<'a> Drop for ConnectionPoolGuard<'a> {
986 fn drop(&mut self) {
987 #[cfg(test)]
988 self.check_invariants();
989 }
990}
991
992fn broadcast<F>(
993 sender_id: Option<ConnectionId>,
994 receiver_ids: impl IntoIterator<Item = ConnectionId>,
995 mut f: F,
996) where
997 F: FnMut(ConnectionId) -> anyhow::Result<()>,
998{
999 for receiver_id in receiver_ids {
1000 if Some(receiver_id) != sender_id {
1001 if let Err(error) = f(receiver_id) {
1002 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1003 }
1004 }
1005 }
1006}
1007
1008pub struct ProtocolVersion(u32);
1009
1010impl Header for ProtocolVersion {
1011 fn name() -> &'static HeaderName {
1012 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1013 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1014 }
1015
1016 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1017 where
1018 Self: Sized,
1019 I: Iterator<Item = &'i axum::http::HeaderValue>,
1020 {
1021 let version = values
1022 .next()
1023 .ok_or_else(axum::headers::Error::invalid)?
1024 .to_str()
1025 .map_err(|_| axum::headers::Error::invalid())?
1026 .parse()
1027 .map_err(|_| axum::headers::Error::invalid())?;
1028 Ok(Self(version))
1029 }
1030
1031 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1032 values.extend([self.0.to_string().parse().unwrap()]);
1033 }
1034}
1035
1036pub struct AppVersionHeader(SemanticVersion);
1037impl Header for AppVersionHeader {
1038 fn name() -> &'static HeaderName {
1039 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1040 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1041 }
1042
1043 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1044 where
1045 Self: Sized,
1046 I: Iterator<Item = &'i axum::http::HeaderValue>,
1047 {
1048 let version = values
1049 .next()
1050 .ok_or_else(axum::headers::Error::invalid)?
1051 .to_str()
1052 .map_err(|_| axum::headers::Error::invalid())?
1053 .parse()
1054 .map_err(|_| axum::headers::Error::invalid())?;
1055 Ok(Self(version))
1056 }
1057
1058 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1059 values.extend([self.0.to_string().parse().unwrap()]);
1060 }
1061}
1062
1063pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1064 Router::new()
1065 .route("/rpc", get(handle_websocket_request))
1066 .layer(
1067 ServiceBuilder::new()
1068 .layer(Extension(server.app_state.clone()))
1069 .layer(middleware::from_fn(auth::validate_header)),
1070 )
1071 .route("/metrics", get(handle_metrics))
1072 .layer(Extension(server))
1073}
1074
1075#[allow(clippy::too_many_arguments)]
1076pub async fn handle_websocket_request(
1077 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1078 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1079 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1080 Extension(server): Extension<Arc<Server>>,
1081 Extension(principal): Extension<Principal>,
1082 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1083 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1084 ws: WebSocketUpgrade,
1085) -> axum::response::Response {
1086 if protocol_version != rpc::PROTOCOL_VERSION {
1087 return (
1088 StatusCode::UPGRADE_REQUIRED,
1089 "client must be upgraded".to_string(),
1090 )
1091 .into_response();
1092 }
1093
1094 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1095 return (
1096 StatusCode::UPGRADE_REQUIRED,
1097 "no version header found".to_string(),
1098 )
1099 .into_response();
1100 };
1101
1102 if !version.can_collaborate() {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "client must be upgraded".to_string(),
1106 )
1107 .into_response();
1108 }
1109
1110 let socket_address = socket_address.to_string();
1111 ws.on_upgrade(move |socket| {
1112 let socket = socket
1113 .map_ok(to_tungstenite_message)
1114 .err_into()
1115 .with(|message| async move { to_axum_message(message) });
1116 let connection = Connection::new(Box::pin(socket));
1117 async move {
1118 server
1119 .handle_connection(
1120 connection,
1121 socket_address,
1122 principal,
1123 version,
1124 country_code_header.map(|header| header.to_string()),
1125 system_id_header.map(|header| header.to_string()),
1126 None,
1127 Executor::Production,
1128 )
1129 .await;
1130 }
1131 })
1132}
1133
1134pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1135 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1136 let connections_metric = CONNECTIONS_METRIC
1137 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1138
1139 let connections = server
1140 .connection_pool
1141 .lock()
1142 .connections()
1143 .filter(|connection| !connection.admin)
1144 .count();
1145 connections_metric.set(connections as _);
1146
1147 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1148 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1149 register_int_gauge!(
1150 "shared_projects",
1151 "number of open projects with one or more guests"
1152 )
1153 .unwrap()
1154 });
1155
1156 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1157 shared_projects_metric.set(shared_projects as _);
1158
1159 let encoder = prometheus::TextEncoder::new();
1160 let metric_families = prometheus::gather();
1161 let encoded_metrics = encoder
1162 .encode_to_string(&metric_families)
1163 .map_err(|err| anyhow!("{}", err))?;
1164 Ok(encoded_metrics)
1165}
1166
/// Cleans up after a connection has gone away.
///
/// The connection is disconnected from the peer, removed from the connection
/// pool and reported to the database immediately. The function then waits for
/// whichever happens first: `RECONNECT_TIMEOUT` elapsing or `teardown`
/// changing. Only when the timeout wins are the session's room and channel
/// buffers left, a pending call declined (if the user has no other connection
/// online) and the user's contacts updated.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A database failure here is traced but does not abort the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in the order written, so the
    // timeout branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Both steps are best-effort: errors are traced and cleanup goes on.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when the user has no connection left.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Teardown was signalled first: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1213
1214/// Acknowledges a ping from a client, used to keep the connection alive.
1215async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1216 response.send(proto::Ack {})?;
1217 Ok(())
1218}
1219
1220/// Creates a new room for calling (outside of channels)
1221async fn create_room(
1222 _request: proto::CreateRoom,
1223 response: Response<proto::CreateRoom>,
1224 session: Session,
1225) -> Result<()> {
1226 let livekit_room = nanoid::nanoid!(30);
1227
1228 let live_kit_connection_info = util::maybe!(async {
1229 let live_kit = session.app_state.livekit_client.as_ref();
1230 let live_kit = live_kit?;
1231 let user_id = session.user_id().to_string();
1232
1233 let token = live_kit
1234 .room_token(&livekit_room, &user_id.to_string())
1235 .trace_err()?;
1236
1237 Some(proto::LiveKitConnectionInfo {
1238 server_url: live_kit.url().into(),
1239 token,
1240 can_publish: true,
1241 })
1242 })
1243 .await;
1244
1245 let room = session
1246 .db()
1247 .await
1248 .create_room(session.user_id(), session.connection_id, &livekit_room)
1249 .await?;
1250
1251 response.send(proto::CreateRoomResponse {
1252 room: Some(room.clone()),
1253 live_kit_connection_info,
1254 })?;
1255
1256 update_user_contacts(session.user_id(), &session).await?;
1257 Ok(())
1258}
1259
1260/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1261async fn join_room(
1262 request: proto::JoinRoom,
1263 response: Response<proto::JoinRoom>,
1264 session: Session,
1265) -> Result<()> {
1266 let room_id = RoomId::from_proto(request.id);
1267
1268 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1269
1270 if let Some(channel_id) = channel_id {
1271 return join_channel_internal(channel_id, Box::new(response), session).await;
1272 }
1273
1274 let joined_room = {
1275 let room = session
1276 .db()
1277 .await
1278 .join_room(room_id, session.user_id(), session.connection_id)
1279 .await?;
1280 room_updated(&room.room, &session.peer);
1281 room.into_inner()
1282 };
1283
1284 for connection_id in session
1285 .connection_pool()
1286 .await
1287 .user_connection_ids(session.user_id())
1288 {
1289 session
1290 .peer
1291 .send(
1292 connection_id,
1293 proto::CallCanceled {
1294 room_id: room_id.to_proto(),
1295 },
1296 )
1297 .trace_err();
1298 }
1299
1300 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1301 {
1302 live_kit
1303 .room_token(
1304 &joined_room.room.livekit_room,
1305 &session.user_id().to_string(),
1306 )
1307 .trace_err()
1308 .map(|token| proto::LiveKitConnectionInfo {
1309 server_url: live_kit.url().into(),
1310 token,
1311 can_publish: true,
1312 })
1313 } else {
1314 None
1315 };
1316
1317 response.send(proto::JoinRoomResponse {
1318 room: Some(joined_room.room),
1319 channel_id: None,
1320 live_kit_connection_info,
1321 })?;
1322
1323 update_user_contacts(session.user_id(), &session).await?;
1324 Ok(())
1325}
1326
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this connection; the response
/// lists the projects that were reshared and rejoined. Collaborators of the
/// reshared projects are told about the new peer id and receive the project's
/// worktree metadata, the rejoined projects are re-synced via
/// `notify_rejoined_projects`, and finally the channel (if any) and the user's
/// contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so that they outlive `rejoined_room`, which
    // is consumed at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first, then notify the other participants.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old connection id has been
            // replaced by this session's connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator other
            // than this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1418
1419fn notify_rejoined_projects(
1420 rejoined_projects: &mut Vec<RejoinedProject>,
1421 session: &Session,
1422) -> Result<()> {
1423 for project in rejoined_projects.iter() {
1424 for collaborator in &project.collaborators {
1425 session
1426 .peer
1427 .send(
1428 collaborator.connection_id,
1429 proto::UpdateProjectCollaborator {
1430 project_id: project.id.to_proto(),
1431 old_peer_id: Some(project.old_connection_id.into()),
1432 new_peer_id: Some(session.connection_id.into()),
1433 },
1434 )
1435 .trace_err();
1436 }
1437 }
1438
1439 for project in rejoined_projects {
1440 for worktree in mem::take(&mut project.worktrees) {
1441 // Stream this worktree's entries.
1442 let message = proto::UpdateWorktree {
1443 project_id: project.id.to_proto(),
1444 worktree_id: worktree.id,
1445 abs_path: worktree.abs_path.clone(),
1446 root_name: worktree.root_name,
1447 updated_entries: worktree.updated_entries,
1448 removed_entries: worktree.removed_entries,
1449 scan_id: worktree.scan_id,
1450 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1451 updated_repositories: worktree.updated_repositories,
1452 removed_repositories: worktree.removed_repositories,
1453 };
1454 for update in proto::split_worktree_update(message) {
1455 session.peer.send(session.connection_id, update.clone())?;
1456 }
1457
1458 // Stream this worktree's diagnostics.
1459 for summary in worktree.diagnostic_summaries {
1460 session.peer.send(
1461 session.connection_id,
1462 proto::UpdateDiagnosticSummary {
1463 project_id: project.id.to_proto(),
1464 worktree_id: worktree.id,
1465 summary: Some(summary),
1466 },
1467 )?;
1468 }
1469
1470 for settings_file in worktree.settings_files {
1471 session.peer.send(
1472 session.connection_id,
1473 proto::UpdateWorktreeSettings {
1474 project_id: project.id.to_proto(),
1475 worktree_id: worktree.id,
1476 path: settings_file.path,
1477 content: Some(settings_file.content),
1478 kind: Some(settings_file.kind.to_proto().into()),
1479 },
1480 )?;
1481 }
1482 }
1483
1484 for language_server in &project.language_servers {
1485 session.peer.send(
1486 session.connection_id,
1487 proto::UpdateLanguageServer {
1488 project_id: project.id.to_proto(),
1489 language_server_id: language_server.id,
1490 variant: Some(
1491 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1492 proto::LspDiskBasedDiagnosticsUpdated {},
1493 ),
1494 ),
1495 },
1496 )?;
1497 }
1498 }
1499 Ok(())
1500}
1501
1502/// leave room disconnects from the room.
1503async fn leave_room(
1504 _: proto::LeaveRoom,
1505 response: Response<proto::LeaveRoom>,
1506 session: Session,
1507) -> Result<()> {
1508 leave_room_for_session(&session, session.connection_id).await?;
1509 response.send(proto::Ack {})?;
1510 Ok(())
1511}
1512
1513/// Updates the permissions of someone else in the room.
1514async fn set_room_participant_role(
1515 request: proto::SetRoomParticipantRole,
1516 response: Response<proto::SetRoomParticipantRole>,
1517 session: Session,
1518) -> Result<()> {
1519 let user_id = UserId::from_proto(request.user_id);
1520 let role = ChannelRole::from(request.role());
1521
1522 let (livekit_room, can_publish) = {
1523 let room = session
1524 .db()
1525 .await
1526 .set_room_participant_role(
1527 session.user_id(),
1528 RoomId::from_proto(request.room_id),
1529 user_id,
1530 role,
1531 )
1532 .await?;
1533
1534 let livekit_room = room.livekit_room.clone();
1535 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1536 room_updated(&room, &session.peer);
1537 (livekit_room, can_publish)
1538 };
1539
1540 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1541 live_kit
1542 .update_participant(
1543 livekit_room.clone(),
1544 request.user_id.to_string(),
1545 livekit_server::proto::ParticipantPermission {
1546 can_subscribe: true,
1547 can_publish,
1548 can_publish_data: can_publish,
1549 hidden: false,
1550 recorder: false,
1551 },
1552 )
1553 .await
1554 .trace_err();
1555 }
1556
1557 response.send(proto::Ack {})?;
1558 Ok(())
1559}
1560
/// Call someone else into the current room
///
/// The called user must be a contact of the caller. The call is recorded in
/// the database, the room update is broadcast, and the incoming-call request
/// is sent to every connection of the called user. The request is
/// acknowledged as soon as any one of those connections responds
/// successfully; if none does, the call is marked as failed and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The value returned by `db.call` is confined to this block; only the
    // incoming-call message is moved out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response completes the call; failures are traced.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: mark the call as failed and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1629
1630/// Cancel an outgoing call.
1631async fn cancel_call(
1632 request: proto::CancelCall,
1633 response: Response<proto::CancelCall>,
1634 session: Session,
1635) -> Result<()> {
1636 let called_user_id = UserId::from_proto(request.called_user_id);
1637 let room_id = RoomId::from_proto(request.room_id);
1638 {
1639 let room = session
1640 .db()
1641 .await
1642 .cancel_call(room_id, session.connection_id, called_user_id)
1643 .await?;
1644 room_updated(&room, &session.peer);
1645 }
1646
1647 for connection_id in session
1648 .connection_pool()
1649 .await
1650 .user_connection_ids(called_user_id)
1651 {
1652 session
1653 .peer
1654 .send(
1655 connection_id,
1656 proto::CallCanceled {
1657 room_id: room_id.to_proto(),
1658 },
1659 )
1660 .trace_err();
1661 }
1662 response.send(proto::Ack {})?;
1663
1664 update_user_contacts(called_user_id, &session).await?;
1665 Ok(())
1666}
1667
1668/// Decline an incoming call.
1669async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1670 let room_id = RoomId::from_proto(message.room_id);
1671 {
1672 let room = session
1673 .db()
1674 .await
1675 .decline_call(Some(room_id), session.user_id())
1676 .await?
1677 .ok_or_else(|| anyhow!("failed to decline call"))?;
1678 room_updated(&room, &session.peer);
1679 }
1680
1681 for connection_id in session
1682 .connection_pool()
1683 .await
1684 .user_connection_ids(session.user_id())
1685 {
1686 session
1687 .peer
1688 .send(
1689 connection_id,
1690 proto::CallCanceled {
1691 room_id: room_id.to_proto(),
1692 },
1693 )
1694 .trace_err();
1695 }
1696 update_user_contacts(session.user_id(), &session).await?;
1697 Ok(())
1698}
1699
1700/// Updates other participants in the room with your current location.
1701async fn update_participant_location(
1702 request: proto::UpdateParticipantLocation,
1703 response: Response<proto::UpdateParticipantLocation>,
1704 session: Session,
1705) -> Result<()> {
1706 let room_id = RoomId::from_proto(request.room_id);
1707 let location = request
1708 .location
1709 .ok_or_else(|| anyhow!("invalid location"))?;
1710
1711 let db = session.db().await;
1712 let room = db
1713 .update_room_participant_location(room_id, session.connection_id, location)
1714 .await?;
1715
1716 room_updated(&room, &session.peer);
1717 response.send(proto::Ack {})?;
1718 Ok(())
1719}
1720
1721/// Share a project into the room.
1722async fn share_project(
1723 request: proto::ShareProject,
1724 response: Response<proto::ShareProject>,
1725 session: Session,
1726) -> Result<()> {
1727 let (project_id, room) = &*session
1728 .db()
1729 .await
1730 .share_project(
1731 RoomId::from_proto(request.room_id),
1732 session.connection_id,
1733 &request.worktrees,
1734 request.is_ssh_project,
1735 )
1736 .await?;
1737 response.send(proto::ShareProjectResponse {
1738 project_id: project_id.to_proto(),
1739 })?;
1740 room_updated(room, &session.peer);
1741
1742 Ok(())
1743}
1744
1745/// Unshare a project from the room.
1746async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1747 let project_id = ProjectId::from_proto(message.project_id);
1748 unshare_project_internal(project_id, session.connection_id, &session).await
1749}
1750
1751async fn unshare_project_internal(
1752 project_id: ProjectId,
1753 connection_id: ConnectionId,
1754 session: &Session,
1755) -> Result<()> {
1756 let delete = {
1757 let room_guard = session
1758 .db()
1759 .await
1760 .unshare_project(project_id, connection_id)
1761 .await?;
1762
1763 let (delete, room, guest_connection_ids) = &*room_guard;
1764
1765 let message = proto::UnshareProject {
1766 project_id: project_id.to_proto(),
1767 };
1768
1769 broadcast(
1770 Some(connection_id),
1771 guest_connection_ids.iter().copied(),
1772 |conn_id| session.peer.send(conn_id, message.clone()),
1773 );
1774 if let Some(room) = room {
1775 room_updated(room, &session.peer);
1776 }
1777
1778 *delete
1779 };
1780
1781 if delete {
1782 let db = session.db().await;
1783 db.delete_project(project_id).await?;
1784 }
1785
1786 Ok(())
1787}
1788
/// Join someone else's shared project.
///
/// Registers this connection as a collaborator of the project in the
/// database, then hands the resulting project and replica id to
/// `join_project_internal`, which sends the response and the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // The database handle is dropped explicitly before the project state is
    // sent; `project` and `replica_id` remain usable after this point.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1807
1808trait JoinProjectInternalResponse {
1809 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1810}
1811impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1812 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1813 Response::<proto::JoinProject>::send(self, result)
1814 }
1815}
1816
1817fn join_project_internal(
1818 response: impl JoinProjectInternalResponse,
1819 session: Session,
1820 project: &mut Project,
1821 replica_id: &ReplicaId,
1822) -> Result<()> {
1823 let collaborators = project
1824 .collaborators
1825 .iter()
1826 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1827 .map(|collaborator| collaborator.to_proto())
1828 .collect::<Vec<_>>();
1829 let project_id = project.id;
1830 let guest_user_id = session.user_id();
1831
1832 let worktrees = project
1833 .worktrees
1834 .iter()
1835 .map(|(id, worktree)| proto::WorktreeMetadata {
1836 id: *id,
1837 root_name: worktree.root_name.clone(),
1838 visible: worktree.visible,
1839 abs_path: worktree.abs_path.clone(),
1840 })
1841 .collect::<Vec<_>>();
1842
1843 let add_project_collaborator = proto::AddProjectCollaborator {
1844 project_id: project_id.to_proto(),
1845 collaborator: Some(proto::Collaborator {
1846 peer_id: Some(session.connection_id.into()),
1847 replica_id: replica_id.0 as u32,
1848 user_id: guest_user_id.to_proto(),
1849 is_host: false,
1850 }),
1851 };
1852
1853 for collaborator in &collaborators {
1854 session
1855 .peer
1856 .send(
1857 collaborator.peer_id.unwrap().into(),
1858 add_project_collaborator.clone(),
1859 )
1860 .trace_err();
1861 }
1862
1863 // First, we send the metadata associated with each worktree.
1864 response.send(proto::JoinProjectResponse {
1865 project_id: project.id.0 as u64,
1866 worktrees: worktrees.clone(),
1867 replica_id: replica_id.0 as u32,
1868 collaborators: collaborators.clone(),
1869 language_servers: project.language_servers.clone(),
1870 role: project.role.into(),
1871 })?;
1872
1873 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1874 // Stream this worktree's entries.
1875 let message = proto::UpdateWorktree {
1876 project_id: project_id.to_proto(),
1877 worktree_id,
1878 abs_path: worktree.abs_path.clone(),
1879 root_name: worktree.root_name,
1880 updated_entries: worktree.entries,
1881 removed_entries: Default::default(),
1882 scan_id: worktree.scan_id,
1883 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1884 updated_repositories: worktree.repository_entries.into_values().collect(),
1885 removed_repositories: Default::default(),
1886 };
1887 for update in proto::split_worktree_update(message) {
1888 session.peer.send(session.connection_id, update.clone())?;
1889 }
1890
1891 // Stream this worktree's diagnostics.
1892 for summary in worktree.diagnostic_summaries {
1893 session.peer.send(
1894 session.connection_id,
1895 proto::UpdateDiagnosticSummary {
1896 project_id: project_id.to_proto(),
1897 worktree_id: worktree.id,
1898 summary: Some(summary),
1899 },
1900 )?;
1901 }
1902
1903 for settings_file in worktree.settings_files {
1904 session.peer.send(
1905 session.connection_id,
1906 proto::UpdateWorktreeSettings {
1907 project_id: project_id.to_proto(),
1908 worktree_id: worktree.id,
1909 path: settings_file.path,
1910 content: Some(settings_file.content),
1911 kind: Some(settings_file.kind.to_proto() as i32),
1912 },
1913 )?;
1914 }
1915 }
1916
1917 for language_server in &project.language_servers {
1918 session.peer.send(
1919 session.connection_id,
1920 proto::UpdateLanguageServer {
1921 project_id: project_id.to_proto(),
1922 language_server_id: language_server.id,
1923 variant: Some(
1924 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1925 proto::LspDiskBasedDiagnosticsUpdated {},
1926 ),
1927 ),
1928 },
1929 )?;
1930 }
1931
1932 Ok(())
1933}
1934
1935/// Leave someone elses shared project.
1936async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1937 let sender_id = session.connection_id;
1938 let project_id = ProjectId::from_proto(request.project_id);
1939 let db = session.db().await;
1940
1941 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1942 tracing::info!(
1943 %project_id,
1944 "leave project"
1945 );
1946
1947 project_left(project, &session);
1948 if let Some(room) = room {
1949 room_updated(room, &session.peer);
1950 }
1951
1952 Ok(())
1953}
1954
1955/// Updates other participants with changes to the project
1956async fn update_project(
1957 request: proto::UpdateProject,
1958 response: Response<proto::UpdateProject>,
1959 session: Session,
1960) -> Result<()> {
1961 let project_id = ProjectId::from_proto(request.project_id);
1962 let (room, guest_connection_ids) = &*session
1963 .db()
1964 .await
1965 .update_project(project_id, session.connection_id, &request.worktrees)
1966 .await?;
1967 broadcast(
1968 Some(session.connection_id),
1969 guest_connection_ids.iter().copied(),
1970 |connection_id| {
1971 session
1972 .peer
1973 .forward_send(session.connection_id, connection_id, request.clone())
1974 },
1975 );
1976 if let Some(room) = room {
1977 room_updated(room, &session.peer);
1978 }
1979 response.send(proto::Ack {})?;
1980
1981 Ok(())
1982}
1983
1984/// Updates other participants with changes to the worktree
1985async fn update_worktree(
1986 request: proto::UpdateWorktree,
1987 response: Response<proto::UpdateWorktree>,
1988 session: Session,
1989) -> Result<()> {
1990 let guest_connection_ids = session
1991 .db()
1992 .await
1993 .update_worktree(&request, session.connection_id)
1994 .await?;
1995
1996 broadcast(
1997 Some(session.connection_id),
1998 guest_connection_ids.iter().copied(),
1999 |connection_id| {
2000 session
2001 .peer
2002 .forward_send(session.connection_id, connection_id, request.clone())
2003 },
2004 );
2005 response.send(proto::Ack {})?;
2006 Ok(())
2007}
2008
2009/// Updates other participants with changes to the diagnostics
2010async fn update_diagnostic_summary(
2011 message: proto::UpdateDiagnosticSummary,
2012 session: Session,
2013) -> Result<()> {
2014 let guest_connection_ids = session
2015 .db()
2016 .await
2017 .update_diagnostic_summary(&message, session.connection_id)
2018 .await?;
2019
2020 broadcast(
2021 Some(session.connection_id),
2022 guest_connection_ids.iter().copied(),
2023 |connection_id| {
2024 session
2025 .peer
2026 .forward_send(session.connection_id, connection_id, message.clone())
2027 },
2028 );
2029
2030 Ok(())
2031}
2032
2033/// Updates other participants with changes to the worktree settings
2034async fn update_worktree_settings(
2035 message: proto::UpdateWorktreeSettings,
2036 session: Session,
2037) -> Result<()> {
2038 let guest_connection_ids = session
2039 .db()
2040 .await
2041 .update_worktree_settings(&message, session.connection_id)
2042 .await?;
2043
2044 broadcast(
2045 Some(session.connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |connection_id| {
2048 session
2049 .peer
2050 .forward_send(session.connection_id, connection_id, message.clone())
2051 },
2052 );
2053
2054 Ok(())
2055}
2056
2057/// Notify other participants that a language server has started.
2058async fn start_language_server(
2059 request: proto::StartLanguageServer,
2060 session: Session,
2061) -> Result<()> {
2062 let guest_connection_ids = session
2063 .db()
2064 .await
2065 .start_language_server(&request, session.connection_id)
2066 .await?;
2067
2068 broadcast(
2069 Some(session.connection_id),
2070 guest_connection_ids.iter().copied(),
2071 |connection_id| {
2072 session
2073 .peer
2074 .forward_send(session.connection_id, connection_id, request.clone())
2075 },
2076 );
2077 Ok(())
2078}
2079
2080/// Notify other participants that a language server has changed.
2081async fn update_language_server(
2082 request: proto::UpdateLanguageServer,
2083 session: Session,
2084) -> Result<()> {
2085 let project_id = ProjectId::from_proto(request.project_id);
2086 let project_connection_ids = session
2087 .db()
2088 .await
2089 .project_connection_ids(project_id, session.connection_id, true)
2090 .await?;
2091 broadcast(
2092 Some(session.connection_id),
2093 project_connection_ids.iter().copied(),
2094 |connection_id| {
2095 session
2096 .peer
2097 .forward_send(session.connection_id, connection_id, request.clone())
2098 },
2099 );
2100 Ok(())
2101}
2102
2103/// forward a project request to the host. These requests should be read only
2104/// as guests are allowed to send them.
2105async fn forward_read_only_project_request<T>(
2106 request: T,
2107 response: Response<T>,
2108 session: Session,
2109) -> Result<()>
2110where
2111 T: EntityMessage + RequestMessage,
2112{
2113 let project_id = ProjectId::from_proto(request.remote_entity_id());
2114 let host_connection_id = session
2115 .db()
2116 .await
2117 .host_for_read_only_project_request(project_id, session.connection_id)
2118 .await?;
2119 let payload = session
2120 .peer
2121 .forward_request(session.connection_id, host_connection_id, request)
2122 .await?;
2123 response.send(payload)?;
2124 Ok(())
2125}
2126
2127async fn forward_find_search_candidates_request(
2128 request: proto::FindSearchCandidates,
2129 response: Response<proto::FindSearchCandidates>,
2130 session: Session,
2131) -> Result<()> {
2132 let project_id = ProjectId::from_proto(request.remote_entity_id());
2133 let host_connection_id = session
2134 .db()
2135 .await
2136 .host_for_read_only_project_request(project_id, session.connection_id)
2137 .await?;
2138 let payload = session
2139 .peer
2140 .forward_request(session.connection_id, host_connection_id, request)
2141 .await?;
2142 response.send(payload)?;
2143 Ok(())
2144}
2145
2146/// forward a project request to the host. These requests are disallowed
2147/// for guests.
2148async fn forward_mutating_project_request<T>(
2149 request: T,
2150 response: Response<T>,
2151 session: Session,
2152) -> Result<()>
2153where
2154 T: EntityMessage + RequestMessage,
2155{
2156 let project_id = ProjectId::from_proto(request.remote_entity_id());
2157
2158 let host_connection_id = session
2159 .db()
2160 .await
2161 .host_for_mutating_project_request(project_id, session.connection_id)
2162 .await?;
2163 let payload = session
2164 .peer
2165 .forward_request(session.connection_id, host_connection_id, request)
2166 .await?;
2167 response.send(payload)?;
2168 Ok(())
2169}
2170
2171/// Notify other participants that a new buffer has been created
2172async fn create_buffer_for_peer(
2173 request: proto::CreateBufferForPeer,
2174 session: Session,
2175) -> Result<()> {
2176 session
2177 .db()
2178 .await
2179 .check_user_is_project_host(
2180 ProjectId::from_proto(request.project_id),
2181 session.connection_id,
2182 )
2183 .await?;
2184 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2185 session
2186 .peer
2187 .forward_send(session.connection_id, peer_id.into(), request)?;
2188 Ok(())
2189}
2190
2191/// Notify other participants that a buffer has been updated. This is
2192/// allowed for guests as long as the update is limited to selections.
2193async fn update_buffer(
2194 request: proto::UpdateBuffer,
2195 response: Response<proto::UpdateBuffer>,
2196 session: Session,
2197) -> Result<()> {
2198 let project_id = ProjectId::from_proto(request.project_id);
2199 let mut capability = Capability::ReadOnly;
2200
2201 for op in request.operations.iter() {
2202 match op.variant {
2203 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2204 Some(_) => capability = Capability::ReadWrite,
2205 }
2206 }
2207
2208 let host = {
2209 let guard = session
2210 .db()
2211 .await
2212 .connections_for_buffer_update(project_id, session.connection_id, capability)
2213 .await?;
2214
2215 let (host, guests) = &*guard;
2216
2217 broadcast(
2218 Some(session.connection_id),
2219 guests.clone(),
2220 |connection_id| {
2221 session
2222 .peer
2223 .forward_send(session.connection_id, connection_id, request.clone())
2224 },
2225 );
2226
2227 *host
2228 };
2229
2230 if host != session.connection_id {
2231 session
2232 .peer
2233 .forward_request(session.connection_id, host, request.clone())
2234 .await?;
2235 }
2236
2237 response.send(proto::Ack {})?;
2238 Ok(())
2239}
2240
2241async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2242 let project_id = ProjectId::from_proto(message.project_id);
2243
2244 let operation = message.operation.as_ref().context("invalid operation")?;
2245 let capability = match operation.variant.as_ref() {
2246 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2247 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2248 match buffer_op.variant {
2249 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2250 Capability::ReadOnly
2251 }
2252 _ => Capability::ReadWrite,
2253 }
2254 } else {
2255 Capability::ReadWrite
2256 }
2257 }
2258 Some(_) => Capability::ReadWrite,
2259 None => Capability::ReadOnly,
2260 };
2261
2262 let guard = session
2263 .db()
2264 .await
2265 .connections_for_buffer_update(project_id, session.connection_id, capability)
2266 .await?;
2267
2268 let (host, guests) = &*guard;
2269
2270 broadcast(
2271 Some(session.connection_id),
2272 guests.iter().chain([host]).copied(),
2273 |connection_id| {
2274 session
2275 .peer
2276 .forward_send(session.connection_id, connection_id, message.clone())
2277 },
2278 );
2279
2280 Ok(())
2281}
2282
2283/// Notify other participants that a project has been updated.
2284async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2285 request: T,
2286 session: Session,
2287) -> Result<()> {
2288 let project_id = ProjectId::from_proto(request.remote_entity_id());
2289 let project_connection_ids = session
2290 .db()
2291 .await
2292 .project_connection_ids(project_id, session.connection_id, false)
2293 .await?;
2294
2295 broadcast(
2296 Some(session.connection_id),
2297 project_connection_ids.iter().copied(),
2298 |connection_id| {
2299 session
2300 .peer
2301 .forward_send(session.connection_id, connection_id, request.clone())
2302 },
2303 );
2304 Ok(())
2305}
2306
2307/// Start following another user in a call.
2308async fn follow(
2309 request: proto::Follow,
2310 response: Response<proto::Follow>,
2311 session: Session,
2312) -> Result<()> {
2313 let room_id = RoomId::from_proto(request.room_id);
2314 let project_id = request.project_id.map(ProjectId::from_proto);
2315 let leader_id = request
2316 .leader_id
2317 .ok_or_else(|| anyhow!("invalid leader id"))?
2318 .into();
2319 let follower_id = session.connection_id;
2320
2321 session
2322 .db()
2323 .await
2324 .check_room_participants(room_id, leader_id, session.connection_id)
2325 .await?;
2326
2327 let response_payload = session
2328 .peer
2329 .forward_request(session.connection_id, leader_id, request)
2330 .await?;
2331 response.send(response_payload)?;
2332
2333 if let Some(project_id) = project_id {
2334 let room = session
2335 .db()
2336 .await
2337 .follow(room_id, project_id, leader_id, follower_id)
2338 .await?;
2339 room_updated(&room, &session.peer);
2340 }
2341
2342 Ok(())
2343}
2344
2345/// Stop following another user in a call.
2346async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2347 let room_id = RoomId::from_proto(request.room_id);
2348 let project_id = request.project_id.map(ProjectId::from_proto);
2349 let leader_id = request
2350 .leader_id
2351 .ok_or_else(|| anyhow!("invalid leader id"))?
2352 .into();
2353 let follower_id = session.connection_id;
2354
2355 session
2356 .db()
2357 .await
2358 .check_room_participants(room_id, leader_id, session.connection_id)
2359 .await?;
2360
2361 session
2362 .peer
2363 .forward_send(session.connection_id, leader_id, request)?;
2364
2365 if let Some(project_id) = project_id {
2366 let room = session
2367 .db()
2368 .await
2369 .unfollow(room_id, project_id, leader_id, follower_id)
2370 .await?;
2371 room_updated(&room, &session.peer);
2372 }
2373
2374 Ok(())
2375}
2376
2377/// Notify everyone following you of your current location.
2378async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2379 let room_id = RoomId::from_proto(request.room_id);
2380 let database = session.db.lock().await;
2381
2382 let connection_ids = if let Some(project_id) = request.project_id {
2383 let project_id = ProjectId::from_proto(project_id);
2384 database
2385 .project_connection_ids(project_id, session.connection_id, true)
2386 .await?
2387 } else {
2388 database
2389 .room_connection_ids(room_id, session.connection_id)
2390 .await?
2391 };
2392
2393 // For now, don't send view update messages back to that view's current leader.
2394 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2395 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2396 _ => None,
2397 });
2398
2399 for connection_id in connection_ids.iter().cloned() {
2400 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2401 session
2402 .peer
2403 .forward_send(session.connection_id, connection_id, request.clone())?;
2404 }
2405 }
2406 Ok(())
2407}
2408
2409/// Get public data about users.
2410async fn get_users(
2411 request: proto::GetUsers,
2412 response: Response<proto::GetUsers>,
2413 session: Session,
2414) -> Result<()> {
2415 let user_ids = request
2416 .user_ids
2417 .into_iter()
2418 .map(UserId::from_proto)
2419 .collect();
2420 let users = session
2421 .db()
2422 .await
2423 .get_users_by_ids(user_ids)
2424 .await?
2425 .into_iter()
2426 .map(|user| proto::User {
2427 id: user.id.to_proto(),
2428 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2429 github_login: user.github_login,
2430 email: user.email_address,
2431 name: user.name,
2432 })
2433 .collect();
2434 response.send(proto::UsersResponse { users })?;
2435 Ok(())
2436}
2437
2438/// Search for users (to invite) buy Github login
2439async fn fuzzy_search_users(
2440 request: proto::FuzzySearchUsers,
2441 response: Response<proto::FuzzySearchUsers>,
2442 session: Session,
2443) -> Result<()> {
2444 let query = request.query;
2445 let users = match query.len() {
2446 0 => vec![],
2447 1 | 2 => session
2448 .db()
2449 .await
2450 .get_user_by_github_login(&query)
2451 .await?
2452 .into_iter()
2453 .collect(),
2454 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2455 };
2456 let users = users
2457 .into_iter()
2458 .filter(|user| user.id != session.user_id())
2459 .map(|user| proto::User {
2460 id: user.id.to_proto(),
2461 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2462 github_login: user.github_login,
2463 name: user.name,
2464 email: user.email_address,
2465 })
2466 .collect();
2467 response.send(proto::UsersResponse { users })?;
2468 Ok(())
2469}
2470
2471/// Send a contact request to another user.
2472async fn request_contact(
2473 request: proto::RequestContact,
2474 response: Response<proto::RequestContact>,
2475 session: Session,
2476) -> Result<()> {
2477 let requester_id = session.user_id();
2478 let responder_id = UserId::from_proto(request.responder_id);
2479 if requester_id == responder_id {
2480 return Err(anyhow!("cannot add yourself as a contact"))?;
2481 }
2482
2483 let notifications = session
2484 .db()
2485 .await
2486 .send_contact_request(requester_id, responder_id)
2487 .await?;
2488
2489 // Update outgoing contact requests of requester
2490 let mut update = proto::UpdateContacts::default();
2491 update.outgoing_requests.push(responder_id.to_proto());
2492 for connection_id in session
2493 .connection_pool()
2494 .await
2495 .user_connection_ids(requester_id)
2496 {
2497 session.peer.send(connection_id, update.clone())?;
2498 }
2499
2500 // Update incoming contact requests of responder
2501 let mut update = proto::UpdateContacts::default();
2502 update
2503 .incoming_requests
2504 .push(proto::IncomingContactRequest {
2505 requester_id: requester_id.to_proto(),
2506 });
2507 let connection_pool = session.connection_pool().await;
2508 for connection_id in connection_pool.user_connection_ids(responder_id) {
2509 session.peer.send(connection_id, update.clone())?;
2510 }
2511
2512 send_notifications(&connection_pool, &session.peer, notifications);
2513
2514 response.send(proto::Ack {})?;
2515 Ok(())
2516}
2517
2518/// Accept or decline a contact request
2519async fn respond_to_contact_request(
2520 request: proto::RespondToContactRequest,
2521 response: Response<proto::RespondToContactRequest>,
2522 session: Session,
2523) -> Result<()> {
2524 let responder_id = session.user_id();
2525 let requester_id = UserId::from_proto(request.requester_id);
2526 let db = session.db().await;
2527 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2528 db.dismiss_contact_notification(responder_id, requester_id)
2529 .await?;
2530 } else {
2531 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2532
2533 let notifications = db
2534 .respond_to_contact_request(responder_id, requester_id, accept)
2535 .await?;
2536 let requester_busy = db.is_user_busy(requester_id).await?;
2537 let responder_busy = db.is_user_busy(responder_id).await?;
2538
2539 let pool = session.connection_pool().await;
2540 // Update responder with new contact
2541 let mut update = proto::UpdateContacts::default();
2542 if accept {
2543 update
2544 .contacts
2545 .push(contact_for_user(requester_id, requester_busy, &pool));
2546 }
2547 update
2548 .remove_incoming_requests
2549 .push(requester_id.to_proto());
2550 for connection_id in pool.user_connection_ids(responder_id) {
2551 session.peer.send(connection_id, update.clone())?;
2552 }
2553
2554 // Update requester with new contact
2555 let mut update = proto::UpdateContacts::default();
2556 if accept {
2557 update
2558 .contacts
2559 .push(contact_for_user(responder_id, responder_busy, &pool));
2560 }
2561 update
2562 .remove_outgoing_requests
2563 .push(responder_id.to_proto());
2564
2565 for connection_id in pool.user_connection_ids(requester_id) {
2566 session.peer.send(connection_id, update.clone())?;
2567 }
2568
2569 send_notifications(&pool, &session.peer, notifications);
2570 }
2571
2572 response.send(proto::Ack {})?;
2573 Ok(())
2574}
2575
2576/// Remove a contact.
2577async fn remove_contact(
2578 request: proto::RemoveContact,
2579 response: Response<proto::RemoveContact>,
2580 session: Session,
2581) -> Result<()> {
2582 let requester_id = session.user_id();
2583 let responder_id = UserId::from_proto(request.user_id);
2584 let db = session.db().await;
2585 let (contact_accepted, deleted_notification_id) =
2586 db.remove_contact(requester_id, responder_id).await?;
2587
2588 let pool = session.connection_pool().await;
2589 // Update outgoing contact requests of requester
2590 let mut update = proto::UpdateContacts::default();
2591 if contact_accepted {
2592 update.remove_contacts.push(responder_id.to_proto());
2593 } else {
2594 update
2595 .remove_outgoing_requests
2596 .push(responder_id.to_proto());
2597 }
2598 for connection_id in pool.user_connection_ids(requester_id) {
2599 session.peer.send(connection_id, update.clone())?;
2600 }
2601
2602 // Update incoming contact requests of responder
2603 let mut update = proto::UpdateContacts::default();
2604 if contact_accepted {
2605 update.remove_contacts.push(requester_id.to_proto());
2606 } else {
2607 update
2608 .remove_incoming_requests
2609 .push(requester_id.to_proto());
2610 }
2611 for connection_id in pool.user_connection_ids(responder_id) {
2612 session.peer.send(connection_id, update.clone())?;
2613 if let Some(notification_id) = deleted_notification_id {
2614 session.peer.send(
2615 connection_id,
2616 proto::DeleteNotification {
2617 notification_id: notification_id.to_proto(),
2618 },
2619 )?;
2620 }
2621 }
2622
2623 response.send(proto::Ack {})?;
2624 Ok(())
2625}
2626
2627fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2628 version.0.minor() < 139
2629}
2630
2631async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2632 let plan = session.current_plan(&session.db().await).await?;
2633
2634 session
2635 .peer
2636 .send(
2637 session.connection_id,
2638 proto::UpdateUserPlan { plan: plan.into() },
2639 )
2640 .trace_err();
2641
2642 Ok(())
2643}
2644
2645async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2646 subscribe_user_to_channels(session.user_id(), &session).await?;
2647 Ok(())
2648}
2649
2650async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2651 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2652 let mut pool = session.connection_pool().await;
2653 for membership in &channels_for_user.channel_memberships {
2654 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2655 }
2656 session.peer.send(
2657 session.connection_id,
2658 build_update_user_channels(&channels_for_user),
2659 )?;
2660 session.peer.send(
2661 session.connection_id,
2662 build_channels_update(channels_for_user),
2663 )?;
2664 Ok(())
2665}
2666
2667/// Creates a new channel.
2668async fn create_channel(
2669 request: proto::CreateChannel,
2670 response: Response<proto::CreateChannel>,
2671 session: Session,
2672) -> Result<()> {
2673 let db = session.db().await;
2674
2675 let parent_id = request.parent_id.map(ChannelId::from_proto);
2676 let (channel, membership) = db
2677 .create_channel(&request.name, parent_id, session.user_id())
2678 .await?;
2679
2680 let root_id = channel.root_id();
2681 let channel = Channel::from_model(channel);
2682
2683 response.send(proto::CreateChannelResponse {
2684 channel: Some(channel.to_proto()),
2685 parent_id: request.parent_id,
2686 })?;
2687
2688 let mut connection_pool = session.connection_pool().await;
2689 if let Some(membership) = membership {
2690 connection_pool.subscribe_to_channel(
2691 membership.user_id,
2692 membership.channel_id,
2693 membership.role,
2694 );
2695 let update = proto::UpdateUserChannels {
2696 channel_memberships: vec![proto::ChannelMembership {
2697 channel_id: membership.channel_id.to_proto(),
2698 role: membership.role.into(),
2699 }],
2700 ..Default::default()
2701 };
2702 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2703 session.peer.send(connection_id, update.clone())?;
2704 }
2705 }
2706
2707 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2708 if !role.can_see_channel(channel.visibility) {
2709 continue;
2710 }
2711
2712 let update = proto::UpdateChannels {
2713 channels: vec![channel.to_proto()],
2714 ..Default::default()
2715 };
2716 session.peer.send(connection_id, update.clone())?;
2717 }
2718
2719 Ok(())
2720}
2721
2722/// Delete a channel
2723async fn delete_channel(
2724 request: proto::DeleteChannel,
2725 response: Response<proto::DeleteChannel>,
2726 session: Session,
2727) -> Result<()> {
2728 let db = session.db().await;
2729
2730 let channel_id = request.channel_id;
2731 let (root_channel, removed_channels) = db
2732 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2733 .await?;
2734 response.send(proto::Ack {})?;
2735
2736 // Notify members of removed channels
2737 let mut update = proto::UpdateChannels::default();
2738 update
2739 .delete_channels
2740 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2741
2742 let connection_pool = session.connection_pool().await;
2743 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2744 session.peer.send(connection_id, update.clone())?;
2745 }
2746
2747 Ok(())
2748}
2749
2750/// Invite someone to join a channel.
2751async fn invite_channel_member(
2752 request: proto::InviteChannelMember,
2753 response: Response<proto::InviteChannelMember>,
2754 session: Session,
2755) -> Result<()> {
2756 let db = session.db().await;
2757 let channel_id = ChannelId::from_proto(request.channel_id);
2758 let invitee_id = UserId::from_proto(request.user_id);
2759 let InviteMemberResult {
2760 channel,
2761 notifications,
2762 } = db
2763 .invite_channel_member(
2764 channel_id,
2765 invitee_id,
2766 session.user_id(),
2767 request.role().into(),
2768 )
2769 .await?;
2770
2771 let update = proto::UpdateChannels {
2772 channel_invitations: vec![channel.to_proto()],
2773 ..Default::default()
2774 };
2775
2776 let connection_pool = session.connection_pool().await;
2777 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2778 session.peer.send(connection_id, update.clone())?;
2779 }
2780
2781 send_notifications(&connection_pool, &session.peer, notifications);
2782
2783 response.send(proto::Ack {})?;
2784 Ok(())
2785}
2786
2787/// remove someone from a channel
2788async fn remove_channel_member(
2789 request: proto::RemoveChannelMember,
2790 response: Response<proto::RemoveChannelMember>,
2791 session: Session,
2792) -> Result<()> {
2793 let db = session.db().await;
2794 let channel_id = ChannelId::from_proto(request.channel_id);
2795 let member_id = UserId::from_proto(request.user_id);
2796
2797 let RemoveChannelMemberResult {
2798 membership_update,
2799 notification_id,
2800 } = db
2801 .remove_channel_member(channel_id, member_id, session.user_id())
2802 .await?;
2803
2804 let mut connection_pool = session.connection_pool().await;
2805 notify_membership_updated(
2806 &mut connection_pool,
2807 membership_update,
2808 member_id,
2809 &session.peer,
2810 );
2811 for connection_id in connection_pool.user_connection_ids(member_id) {
2812 if let Some(notification_id) = notification_id {
2813 session
2814 .peer
2815 .send(
2816 connection_id,
2817 proto::DeleteNotification {
2818 notification_id: notification_id.to_proto(),
2819 },
2820 )
2821 .trace_err();
2822 }
2823 }
2824
2825 response.send(proto::Ack {})?;
2826 Ok(())
2827}
2828
2829/// Toggle the channel between public and private.
2830/// Care is taken to maintain the invariant that public channels only descend from public channels,
2831/// (though members-only channels can appear at any point in the hierarchy).
2832async fn set_channel_visibility(
2833 request: proto::SetChannelVisibility,
2834 response: Response<proto::SetChannelVisibility>,
2835 session: Session,
2836) -> Result<()> {
2837 let db = session.db().await;
2838 let channel_id = ChannelId::from_proto(request.channel_id);
2839 let visibility = request.visibility().into();
2840
2841 let channel_model = db
2842 .set_channel_visibility(channel_id, visibility, session.user_id())
2843 .await?;
2844 let root_id = channel_model.root_id();
2845 let channel = Channel::from_model(channel_model);
2846
2847 let mut connection_pool = session.connection_pool().await;
2848 for (user_id, role) in connection_pool
2849 .channel_user_ids(root_id)
2850 .collect::<Vec<_>>()
2851 .into_iter()
2852 {
2853 let update = if role.can_see_channel(channel.visibility) {
2854 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2855 proto::UpdateChannels {
2856 channels: vec![channel.to_proto()],
2857 ..Default::default()
2858 }
2859 } else {
2860 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2861 proto::UpdateChannels {
2862 delete_channels: vec![channel.id.to_proto()],
2863 ..Default::default()
2864 }
2865 };
2866
2867 for connection_id in connection_pool.user_connection_ids(user_id) {
2868 session.peer.send(connection_id, update.clone())?;
2869 }
2870 }
2871
2872 response.send(proto::Ack {})?;
2873 Ok(())
2874}
2875
2876/// Alter the role for a user in the channel.
2877async fn set_channel_member_role(
2878 request: proto::SetChannelMemberRole,
2879 response: Response<proto::SetChannelMemberRole>,
2880 session: Session,
2881) -> Result<()> {
2882 let db = session.db().await;
2883 let channel_id = ChannelId::from_proto(request.channel_id);
2884 let member_id = UserId::from_proto(request.user_id);
2885 let result = db
2886 .set_channel_member_role(
2887 channel_id,
2888 session.user_id(),
2889 member_id,
2890 request.role().into(),
2891 )
2892 .await?;
2893
2894 match result {
2895 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2896 let mut connection_pool = session.connection_pool().await;
2897 notify_membership_updated(
2898 &mut connection_pool,
2899 membership_update,
2900 member_id,
2901 &session.peer,
2902 )
2903 }
2904 db::SetMemberRoleResult::InviteUpdated(channel) => {
2905 let update = proto::UpdateChannels {
2906 channel_invitations: vec![channel.to_proto()],
2907 ..Default::default()
2908 };
2909
2910 for connection_id in session
2911 .connection_pool()
2912 .await
2913 .user_connection_ids(member_id)
2914 {
2915 session.peer.send(connection_id, update.clone())?;
2916 }
2917 }
2918 }
2919
2920 response.send(proto::Ack {})?;
2921 Ok(())
2922}
2923
2924/// Change the name of a channel
2925async fn rename_channel(
2926 request: proto::RenameChannel,
2927 response: Response<proto::RenameChannel>,
2928 session: Session,
2929) -> Result<()> {
2930 let db = session.db().await;
2931 let channel_id = ChannelId::from_proto(request.channel_id);
2932 let channel_model = db
2933 .rename_channel(channel_id, session.user_id(), &request.name)
2934 .await?;
2935 let root_id = channel_model.root_id();
2936 let channel = Channel::from_model(channel_model);
2937
2938 response.send(proto::RenameChannelResponse {
2939 channel: Some(channel.to_proto()),
2940 })?;
2941
2942 let connection_pool = session.connection_pool().await;
2943 let update = proto::UpdateChannels {
2944 channels: vec![channel.to_proto()],
2945 ..Default::default()
2946 };
2947 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2948 if role.can_see_channel(channel.visibility) {
2949 session.peer.send(connection_id, update.clone())?;
2950 }
2951 }
2952
2953 Ok(())
2954}
2955
2956/// Move a channel to a new parent.
2957async fn move_channel(
2958 request: proto::MoveChannel,
2959 response: Response<proto::MoveChannel>,
2960 session: Session,
2961) -> Result<()> {
2962 let channel_id = ChannelId::from_proto(request.channel_id);
2963 let to = ChannelId::from_proto(request.to);
2964
2965 let (root_id, channels) = session
2966 .db()
2967 .await
2968 .move_channel(channel_id, to, session.user_id())
2969 .await?;
2970
2971 let connection_pool = session.connection_pool().await;
2972 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2973 let channels = channels
2974 .iter()
2975 .filter_map(|channel| {
2976 if role.can_see_channel(channel.visibility) {
2977 Some(channel.to_proto())
2978 } else {
2979 None
2980 }
2981 })
2982 .collect::<Vec<_>>();
2983 if channels.is_empty() {
2984 continue;
2985 }
2986
2987 let update = proto::UpdateChannels {
2988 channels,
2989 ..Default::default()
2990 };
2991
2992 session.peer.send(connection_id, update.clone())?;
2993 }
2994
2995 response.send(Ack {})?;
2996 Ok(())
2997}
2998
2999/// Get the list of channel members
3000async fn get_channel_members(
3001 request: proto::GetChannelMembers,
3002 response: Response<proto::GetChannelMembers>,
3003 session: Session,
3004) -> Result<()> {
3005 let db = session.db().await;
3006 let channel_id = ChannelId::from_proto(request.channel_id);
3007 let limit = if request.limit == 0 {
3008 u16::MAX as u64
3009 } else {
3010 request.limit
3011 };
3012 let (members, users) = db
3013 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3014 .await?;
3015 response.send(proto::GetChannelMembersResponse { members, users })?;
3016 Ok(())
3017}
3018
3019/// Accept or decline a channel invitation.
3020async fn respond_to_channel_invite(
3021 request: proto::RespondToChannelInvite,
3022 response: Response<proto::RespondToChannelInvite>,
3023 session: Session,
3024) -> Result<()> {
3025 let db = session.db().await;
3026 let channel_id = ChannelId::from_proto(request.channel_id);
3027 let RespondToChannelInvite {
3028 membership_update,
3029 notifications,
3030 } = db
3031 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3032 .await?;
3033
3034 let mut connection_pool = session.connection_pool().await;
3035 if let Some(membership_update) = membership_update {
3036 notify_membership_updated(
3037 &mut connection_pool,
3038 membership_update,
3039 session.user_id(),
3040 &session.peer,
3041 );
3042 } else {
3043 let update = proto::UpdateChannels {
3044 remove_channel_invitations: vec![channel_id.to_proto()],
3045 ..Default::default()
3046 };
3047
3048 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3049 session.peer.send(connection_id, update.clone())?;
3050 }
3051 };
3052
3053 send_notifications(&connection_pool, &session.peer, notifications);
3054
3055 response.send(proto::Ack {})?;
3056
3057 Ok(())
3058}
3059
3060/// Join the channels' room
3061async fn join_channel(
3062 request: proto::JoinChannel,
3063 response: Response<proto::JoinChannel>,
3064 session: Session,
3065) -> Result<()> {
3066 let channel_id = ChannelId::from_proto(request.channel_id);
3067 join_channel_internal(channel_id, Box::new(response), session).await
3068}
3069
3070trait JoinChannelInternalResponse {
3071 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3072}
3073impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3074 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3075 Response::<proto::JoinChannel>::send(self, result)
3076 }
3077}
3078impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3079 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3080 Response::<proto::JoinRoom>::send(self, result)
3081 }
3082}
3083
3084async fn join_channel_internal(
3085 channel_id: ChannelId,
3086 response: Box<impl JoinChannelInternalResponse>,
3087 session: Session,
3088) -> Result<()> {
3089 let joined_room = {
3090 let mut db = session.db().await;
3091 // If zed quits without leaving the room, and the user re-opens zed before the
3092 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3093 // room they were in.
3094 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3095 tracing::info!(
3096 stale_connection_id = %connection,
3097 "cleaning up stale connection",
3098 );
3099 drop(db);
3100 leave_room_for_session(&session, connection).await?;
3101 db = session.db().await;
3102 }
3103
3104 let (joined_room, membership_updated, role) = db
3105 .join_channel(channel_id, session.user_id(), session.connection_id)
3106 .await?;
3107
3108 let live_kit_connection_info =
3109 session
3110 .app_state
3111 .livekit_client
3112 .as_ref()
3113 .and_then(|live_kit| {
3114 let (can_publish, token) = if role == ChannelRole::Guest {
3115 (
3116 false,
3117 live_kit
3118 .guest_token(
3119 &joined_room.room.livekit_room,
3120 &session.user_id().to_string(),
3121 )
3122 .trace_err()?,
3123 )
3124 } else {
3125 (
3126 true,
3127 live_kit
3128 .room_token(
3129 &joined_room.room.livekit_room,
3130 &session.user_id().to_string(),
3131 )
3132 .trace_err()?,
3133 )
3134 };
3135
3136 Some(LiveKitConnectionInfo {
3137 server_url: live_kit.url().into(),
3138 token,
3139 can_publish,
3140 })
3141 });
3142
3143 response.send(proto::JoinRoomResponse {
3144 room: Some(joined_room.room.clone()),
3145 channel_id: joined_room
3146 .channel
3147 .as_ref()
3148 .map(|channel| channel.id.to_proto()),
3149 live_kit_connection_info,
3150 })?;
3151
3152 let mut connection_pool = session.connection_pool().await;
3153 if let Some(membership_updated) = membership_updated {
3154 notify_membership_updated(
3155 &mut connection_pool,
3156 membership_updated,
3157 session.user_id(),
3158 &session.peer,
3159 );
3160 }
3161
3162 room_updated(&joined_room.room, &session.peer);
3163
3164 joined_room
3165 };
3166
3167 channel_updated(
3168 &joined_room
3169 .channel
3170 .ok_or_else(|| anyhow!("channel not returned"))?,
3171 &joined_room.room,
3172 &session.peer,
3173 &*session.connection_pool().await,
3174 );
3175
3176 update_user_contacts(session.user_id(), &session).await?;
3177 Ok(())
3178}
3179
3180/// Start editing the channel notes
3181async fn join_channel_buffer(
3182 request: proto::JoinChannelBuffer,
3183 response: Response<proto::JoinChannelBuffer>,
3184 session: Session,
3185) -> Result<()> {
3186 let db = session.db().await;
3187 let channel_id = ChannelId::from_proto(request.channel_id);
3188
3189 let open_response = db
3190 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3191 .await?;
3192
3193 let collaborators = open_response.collaborators.clone();
3194 response.send(open_response)?;
3195
3196 let update = UpdateChannelBufferCollaborators {
3197 channel_id: channel_id.to_proto(),
3198 collaborators: collaborators.clone(),
3199 };
3200 channel_buffer_updated(
3201 session.connection_id,
3202 collaborators
3203 .iter()
3204 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3205 &update,
3206 &session.peer,
3207 );
3208
3209 Ok(())
3210}
3211
3212/// Edit the channel notes
3213async fn update_channel_buffer(
3214 request: proto::UpdateChannelBuffer,
3215 session: Session,
3216) -> Result<()> {
3217 let db = session.db().await;
3218 let channel_id = ChannelId::from_proto(request.channel_id);
3219
3220 let (collaborators, epoch, version) = db
3221 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3222 .await?;
3223
3224 channel_buffer_updated(
3225 session.connection_id,
3226 collaborators.clone(),
3227 &proto::UpdateChannelBuffer {
3228 channel_id: channel_id.to_proto(),
3229 operations: request.operations,
3230 },
3231 &session.peer,
3232 );
3233
3234 let pool = &*session.connection_pool().await;
3235
3236 let non_collaborators =
3237 pool.channel_connection_ids(channel_id)
3238 .filter_map(|(connection_id, _)| {
3239 if collaborators.contains(&connection_id) {
3240 None
3241 } else {
3242 Some(connection_id)
3243 }
3244 });
3245
3246 broadcast(None, non_collaborators, |peer_id| {
3247 session.peer.send(
3248 peer_id,
3249 proto::UpdateChannels {
3250 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3251 channel_id: channel_id.to_proto(),
3252 epoch: epoch as u64,
3253 version: version.clone(),
3254 }],
3255 ..Default::default()
3256 },
3257 )
3258 });
3259
3260 Ok(())
3261}
3262
3263/// Rejoin the channel notes after a connection blip
3264async fn rejoin_channel_buffers(
3265 request: proto::RejoinChannelBuffers,
3266 response: Response<proto::RejoinChannelBuffers>,
3267 session: Session,
3268) -> Result<()> {
3269 let db = session.db().await;
3270 let buffers = db
3271 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3272 .await?;
3273
3274 for rejoined_buffer in &buffers {
3275 let collaborators_to_notify = rejoined_buffer
3276 .buffer
3277 .collaborators
3278 .iter()
3279 .filter_map(|c| Some(c.peer_id?.into()));
3280 channel_buffer_updated(
3281 session.connection_id,
3282 collaborators_to_notify,
3283 &proto::UpdateChannelBufferCollaborators {
3284 channel_id: rejoined_buffer.buffer.channel_id,
3285 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3286 },
3287 &session.peer,
3288 );
3289 }
3290
3291 response.send(proto::RejoinChannelBuffersResponse {
3292 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3293 })?;
3294
3295 Ok(())
3296}
3297
3298/// Stop editing the channel notes
3299async fn leave_channel_buffer(
3300 request: proto::LeaveChannelBuffer,
3301 response: Response<proto::LeaveChannelBuffer>,
3302 session: Session,
3303) -> Result<()> {
3304 let db = session.db().await;
3305 let channel_id = ChannelId::from_proto(request.channel_id);
3306
3307 let left_buffer = db
3308 .leave_channel_buffer(channel_id, session.connection_id)
3309 .await?;
3310
3311 response.send(Ack {})?;
3312
3313 channel_buffer_updated(
3314 session.connection_id,
3315 left_buffer.connections,
3316 &proto::UpdateChannelBufferCollaborators {
3317 channel_id: channel_id.to_proto(),
3318 collaborators: left_buffer.collaborators,
3319 },
3320 &session.peer,
3321 );
3322
3323 Ok(())
3324}
3325
3326fn channel_buffer_updated<T: EnvelopedMessage>(
3327 sender_id: ConnectionId,
3328 collaborators: impl IntoIterator<Item = ConnectionId>,
3329 message: &T,
3330 peer: &Peer,
3331) {
3332 broadcast(Some(sender_id), collaborators, |peer_id| {
3333 peer.send(peer_id, message.clone())
3334 });
3335}
3336
3337fn send_notifications(
3338 connection_pool: &ConnectionPool,
3339 peer: &Peer,
3340 notifications: db::NotificationBatch,
3341) {
3342 for (user_id, notification) in notifications {
3343 for connection_id in connection_pool.user_connection_ids(user_id) {
3344 if let Err(error) = peer.send(
3345 connection_id,
3346 proto::AddNotification {
3347 notification: Some(notification.clone()),
3348 },
3349 ) {
3350 tracing::error!(
3351 "failed to send notification to {:?} {}",
3352 connection_id,
3353 error
3354 );
3355 }
3356 }
3357 }
3358}
3359
3360/// Send a message to the channel
3361async fn send_channel_message(
3362 request: proto::SendChannelMessage,
3363 response: Response<proto::SendChannelMessage>,
3364 session: Session,
3365) -> Result<()> {
3366 // Validate the message body.
3367 let body = request.body.trim().to_string();
3368 if body.len() > MAX_MESSAGE_LEN {
3369 return Err(anyhow!("message is too long"))?;
3370 }
3371 if body.is_empty() {
3372 return Err(anyhow!("message can't be blank"))?;
3373 }
3374
3375 // TODO: adjust mentions if body is trimmed
3376
3377 let timestamp = OffsetDateTime::now_utc();
3378 let nonce = request
3379 .nonce
3380 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3381
3382 let channel_id = ChannelId::from_proto(request.channel_id);
3383 let CreatedChannelMessage {
3384 message_id,
3385 participant_connection_ids,
3386 notifications,
3387 } = session
3388 .db()
3389 .await
3390 .create_channel_message(
3391 channel_id,
3392 session.user_id(),
3393 &body,
3394 &request.mentions,
3395 timestamp,
3396 nonce.clone().into(),
3397 request.reply_to_message_id.map(MessageId::from_proto),
3398 )
3399 .await?;
3400
3401 let message = proto::ChannelMessage {
3402 sender_id: session.user_id().to_proto(),
3403 id: message_id.to_proto(),
3404 body,
3405 mentions: request.mentions,
3406 timestamp: timestamp.unix_timestamp() as u64,
3407 nonce: Some(nonce),
3408 reply_to_message_id: request.reply_to_message_id,
3409 edited_at: None,
3410 };
3411 broadcast(
3412 Some(session.connection_id),
3413 participant_connection_ids.clone(),
3414 |connection| {
3415 session.peer.send(
3416 connection,
3417 proto::ChannelMessageSent {
3418 channel_id: channel_id.to_proto(),
3419 message: Some(message.clone()),
3420 },
3421 )
3422 },
3423 );
3424 response.send(proto::SendChannelMessageResponse {
3425 message: Some(message),
3426 })?;
3427
3428 let pool = &*session.connection_pool().await;
3429 let non_participants =
3430 pool.channel_connection_ids(channel_id)
3431 .filter_map(|(connection_id, _)| {
3432 if participant_connection_ids.contains(&connection_id) {
3433 None
3434 } else {
3435 Some(connection_id)
3436 }
3437 });
3438 broadcast(None, non_participants, |peer_id| {
3439 session.peer.send(
3440 peer_id,
3441 proto::UpdateChannels {
3442 latest_channel_message_ids: vec![proto::ChannelMessageId {
3443 channel_id: channel_id.to_proto(),
3444 message_id: message_id.to_proto(),
3445 }],
3446 ..Default::default()
3447 },
3448 )
3449 });
3450 send_notifications(pool, &session.peer, notifications);
3451
3452 Ok(())
3453}
3454
3455/// Delete a channel message
3456async fn remove_channel_message(
3457 request: proto::RemoveChannelMessage,
3458 response: Response<proto::RemoveChannelMessage>,
3459 session: Session,
3460) -> Result<()> {
3461 let channel_id = ChannelId::from_proto(request.channel_id);
3462 let message_id = MessageId::from_proto(request.message_id);
3463 let (connection_ids, existing_notification_ids) = session
3464 .db()
3465 .await
3466 .remove_channel_message(channel_id, message_id, session.user_id())
3467 .await?;
3468
3469 broadcast(
3470 Some(session.connection_id),
3471 connection_ids,
3472 move |connection| {
3473 session.peer.send(connection, request.clone())?;
3474
3475 for notification_id in &existing_notification_ids {
3476 session.peer.send(
3477 connection,
3478 proto::DeleteNotification {
3479 notification_id: (*notification_id).to_proto(),
3480 },
3481 )?;
3482 }
3483
3484 Ok(())
3485 },
3486 );
3487 response.send(proto::Ack {})?;
3488 Ok(())
3489}
3490
3491async fn update_channel_message(
3492 request: proto::UpdateChannelMessage,
3493 response: Response<proto::UpdateChannelMessage>,
3494 session: Session,
3495) -> Result<()> {
3496 let channel_id = ChannelId::from_proto(request.channel_id);
3497 let message_id = MessageId::from_proto(request.message_id);
3498 let updated_at = OffsetDateTime::now_utc();
3499 let UpdatedChannelMessage {
3500 message_id,
3501 participant_connection_ids,
3502 notifications,
3503 reply_to_message_id,
3504 timestamp,
3505 deleted_mention_notification_ids,
3506 updated_mention_notifications,
3507 } = session
3508 .db()
3509 .await
3510 .update_channel_message(
3511 channel_id,
3512 message_id,
3513 session.user_id(),
3514 request.body.as_str(),
3515 &request.mentions,
3516 updated_at,
3517 )
3518 .await?;
3519
3520 let nonce = request
3521 .nonce
3522 .clone()
3523 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3524
3525 let message = proto::ChannelMessage {
3526 sender_id: session.user_id().to_proto(),
3527 id: message_id.to_proto(),
3528 body: request.body.clone(),
3529 mentions: request.mentions.clone(),
3530 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3531 nonce: Some(nonce),
3532 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3533 edited_at: Some(updated_at.unix_timestamp() as u64),
3534 };
3535
3536 response.send(proto::Ack {})?;
3537
3538 let pool = &*session.connection_pool().await;
3539 broadcast(
3540 Some(session.connection_id),
3541 participant_connection_ids,
3542 |connection| {
3543 session.peer.send(
3544 connection,
3545 proto::ChannelMessageUpdate {
3546 channel_id: channel_id.to_proto(),
3547 message: Some(message.clone()),
3548 },
3549 )?;
3550
3551 for notification_id in &deleted_mention_notification_ids {
3552 session.peer.send(
3553 connection,
3554 proto::DeleteNotification {
3555 notification_id: (*notification_id).to_proto(),
3556 },
3557 )?;
3558 }
3559
3560 for notification in &updated_mention_notifications {
3561 session.peer.send(
3562 connection,
3563 proto::UpdateNotification {
3564 notification: Some(notification.clone()),
3565 },
3566 )?;
3567 }
3568
3569 Ok(())
3570 },
3571 );
3572
3573 send_notifications(pool, &session.peer, notifications);
3574
3575 Ok(())
3576}
3577
3578/// Mark a channel message as read
3579async fn acknowledge_channel_message(
3580 request: proto::AckChannelMessage,
3581 session: Session,
3582) -> Result<()> {
3583 let channel_id = ChannelId::from_proto(request.channel_id);
3584 let message_id = MessageId::from_proto(request.message_id);
3585 let notifications = session
3586 .db()
3587 .await
3588 .observe_channel_message(channel_id, session.user_id(), message_id)
3589 .await?;
3590 send_notifications(
3591 &*session.connection_pool().await,
3592 &session.peer,
3593 notifications,
3594 );
3595 Ok(())
3596}
3597
3598/// Mark a buffer version as synced
3599async fn acknowledge_buffer_version(
3600 request: proto::AckBufferOperation,
3601 session: Session,
3602) -> Result<()> {
3603 let buffer_id = BufferId::from_proto(request.buffer_id);
3604 session
3605 .db()
3606 .await
3607 .observe_buffer_version(
3608 buffer_id,
3609 session.user_id(),
3610 request.epoch as i32,
3611 &request.version,
3612 )
3613 .await?;
3614 Ok(())
3615}
3616
3617async fn count_language_model_tokens(
3618 request: proto::CountLanguageModelTokens,
3619 response: Response<proto::CountLanguageModelTokens>,
3620 session: Session,
3621 config: &Config,
3622) -> Result<()> {
3623 authorize_access_to_legacy_llm_endpoints(&session).await?;
3624
3625 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3626 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3627 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3628 };
3629
3630 session
3631 .app_state
3632 .rate_limiter
3633 .check(&*rate_limit, session.user_id())
3634 .await?;
3635
3636 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3637 Some(proto::LanguageModelProvider::Google) => {
3638 let api_key = config
3639 .google_ai_api_key
3640 .as_ref()
3641 .context("no Google AI API key configured on the server")?;
3642 google_ai::count_tokens(
3643 session.http_client.as_ref(),
3644 google_ai::API_URL,
3645 api_key,
3646 serde_json::from_str(&request.request)?,
3647 )
3648 .await?
3649 }
3650 _ => return Err(anyhow!("unsupported provider"))?,
3651 };
3652
3653 response.send(proto::CountLanguageModelTokensResponse {
3654 token_count: result.total_tokens as u32,
3655 })?;
3656
3657 Ok(())
3658}
3659
3660struct ZedProCountLanguageModelTokensRateLimit;
3661
3662impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3663 fn capacity(&self) -> usize {
3664 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3665 .ok()
3666 .and_then(|v| v.parse().ok())
3667 .unwrap_or(600) // Picked arbitrarily
3668 }
3669
3670 fn refill_duration(&self) -> chrono::Duration {
3671 chrono::Duration::hours(1)
3672 }
3673
3674 fn db_name(&self) -> &'static str {
3675 "zed-pro:count-language-model-tokens"
3676 }
3677}
3678
3679struct FreeCountLanguageModelTokensRateLimit;
3680
3681impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3682 fn capacity(&self) -> usize {
3683 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3684 .ok()
3685 .and_then(|v| v.parse().ok())
3686 .unwrap_or(600 / 10) // Picked arbitrarily
3687 }
3688
3689 fn refill_duration(&self) -> chrono::Duration {
3690 chrono::Duration::hours(1)
3691 }
3692
3693 fn db_name(&self) -> &'static str {
3694 "free:count-language-model-tokens"
3695 }
3696}
3697
3698struct ZedProComputeEmbeddingsRateLimit;
3699
3700impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3701 fn capacity(&self) -> usize {
3702 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3703 .ok()
3704 .and_then(|v| v.parse().ok())
3705 .unwrap_or(5000) // Picked arbitrarily
3706 }
3707
3708 fn refill_duration(&self) -> chrono::Duration {
3709 chrono::Duration::hours(1)
3710 }
3711
3712 fn db_name(&self) -> &'static str {
3713 "zed-pro:compute-embeddings"
3714 }
3715}
3716
3717struct FreeComputeEmbeddingsRateLimit;
3718
3719impl RateLimit for FreeComputeEmbeddingsRateLimit {
3720 fn capacity(&self) -> usize {
3721 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3722 .ok()
3723 .and_then(|v| v.parse().ok())
3724 .unwrap_or(5000 / 10) // Picked arbitrarily
3725 }
3726
3727 fn refill_duration(&self) -> chrono::Duration {
3728 chrono::Duration::hours(1)
3729 }
3730
3731 fn db_name(&self) -> &'static str {
3732 "free:compute-embeddings"
3733 }
3734}
3735
3736async fn compute_embeddings(
3737 request: proto::ComputeEmbeddings,
3738 response: Response<proto::ComputeEmbeddings>,
3739 session: Session,
3740 api_key: Option<Arc<str>>,
3741) -> Result<()> {
3742 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3743 authorize_access_to_legacy_llm_endpoints(&session).await?;
3744
3745 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3746 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3747 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3748 };
3749
3750 session
3751 .app_state
3752 .rate_limiter
3753 .check(&*rate_limit, session.user_id())
3754 .await?;
3755
3756 let embeddings = match request.model.as_str() {
3757 "openai/text-embedding-3-small" => {
3758 open_ai::embed(
3759 session.http_client.as_ref(),
3760 OPEN_AI_API_URL,
3761 &api_key,
3762 OpenAiEmbeddingModel::TextEmbedding3Small,
3763 request.texts.iter().map(|text| text.as_str()),
3764 )
3765 .await?
3766 }
3767 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3768 };
3769
3770 let embeddings = request
3771 .texts
3772 .iter()
3773 .map(|text| {
3774 let mut hasher = sha2::Sha256::new();
3775 hasher.update(text.as_bytes());
3776 let result = hasher.finalize();
3777 result.to_vec()
3778 })
3779 .zip(
3780 embeddings
3781 .data
3782 .into_iter()
3783 .map(|embedding| embedding.embedding),
3784 )
3785 .collect::<HashMap<_, _>>();
3786
3787 let db = session.db().await;
3788 db.save_embeddings(&request.model, &embeddings)
3789 .await
3790 .context("failed to save embeddings")
3791 .trace_err();
3792
3793 response.send(proto::ComputeEmbeddingsResponse {
3794 embeddings: embeddings
3795 .into_iter()
3796 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3797 .collect(),
3798 })?;
3799 Ok(())
3800}
3801
3802async fn get_cached_embeddings(
3803 request: proto::GetCachedEmbeddings,
3804 response: Response<proto::GetCachedEmbeddings>,
3805 session: Session,
3806) -> Result<()> {
3807 authorize_access_to_legacy_llm_endpoints(&session).await?;
3808
3809 let db = session.db().await;
3810 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3811
3812 response.send(proto::GetCachedEmbeddingsResponse {
3813 embeddings: embeddings
3814 .into_iter()
3815 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3816 .collect(),
3817 })?;
3818 Ok(())
3819}
3820
3821/// This is leftover from before the LLM service.
3822///
3823/// The endpoints protected by this check will be moved there eventually.
3824async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3825 if session.is_staff() {
3826 Ok(())
3827 } else {
3828 Err(anyhow!("permission denied"))?
3829 }
3830}
3831
3832/// Get a Supermaven API key for the user
3833async fn get_supermaven_api_key(
3834 _request: proto::GetSupermavenApiKey,
3835 response: Response<proto::GetSupermavenApiKey>,
3836 session: Session,
3837) -> Result<()> {
3838 let user_id: String = session.user_id().to_string();
3839 if !session.is_staff() {
3840 return Err(anyhow!("supermaven not enabled for this account"))?;
3841 }
3842
3843 let email = session
3844 .email()
3845 .ok_or_else(|| anyhow!("user must have an email"))?;
3846
3847 let supermaven_admin_api = session
3848 .supermaven_client
3849 .as_ref()
3850 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3851
3852 let result = supermaven_admin_api
3853 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3854 .await?;
3855
3856 response.send(proto::GetSupermavenApiKeyResponse {
3857 api_key: result.api_key,
3858 })?;
3859
3860 Ok(())
3861}
3862
3863/// Start receiving chat updates for a channel
3864async fn join_channel_chat(
3865 request: proto::JoinChannelChat,
3866 response: Response<proto::JoinChannelChat>,
3867 session: Session,
3868) -> Result<()> {
3869 let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871 let db = session.db().await;
3872 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3873 .await?;
3874 let messages = db
3875 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3876 .await?;
3877 response.send(proto::JoinChannelChatResponse {
3878 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3879 messages,
3880 })?;
3881 Ok(())
3882}
3883
3884/// Stop receiving chat updates for a channel
3885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3886 let channel_id = ChannelId::from_proto(request.channel_id);
3887 session
3888 .db()
3889 .await
3890 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891 .await?;
3892 Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897 request: proto::GetChannelMessages,
3898 response: Response<proto::GetChannelMessages>,
3899 session: Session,
3900) -> Result<()> {
3901 let channel_id = ChannelId::from_proto(request.channel_id);
3902 let messages = session
3903 .db()
3904 .await
3905 .get_channel_messages(
3906 channel_id,
3907 session.user_id(),
3908 MESSAGE_COUNT_PER_PAGE,
3909 Some(MessageId::from_proto(request.before_message_id)),
3910 )
3911 .await?;
3912 response.send(proto::GetChannelMessagesResponse {
3913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914 messages,
3915 })?;
3916 Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921 request: proto::GetChannelMessagesById,
3922 response: Response<proto::GetChannelMessagesById>,
3923 session: Session,
3924) -> Result<()> {
3925 let message_ids = request
3926 .message_ids
3927 .iter()
3928 .map(|id| MessageId::from_proto(*id))
3929 .collect::<Vec<_>>();
3930 let messages = session
3931 .db()
3932 .await
3933 .get_channel_messages_by_id(session.user_id(), &message_ids)
3934 .await?;
3935 response.send(proto::GetChannelMessagesResponse {
3936 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937 messages,
3938 })?;
3939 Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944 request: proto::GetNotifications,
3945 response: Response<proto::GetNotifications>,
3946 session: Session,
3947) -> Result<()> {
3948 let notifications = session
3949 .db()
3950 .await
3951 .get_notifications(
3952 session.user_id(),
3953 NOTIFICATION_COUNT_PER_PAGE,
3954 request.before_id.map(db::NotificationId::from_proto),
3955 )
3956 .await?;
3957 response.send(proto::GetNotificationsResponse {
3958 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959 notifications,
3960 })?;
3961 Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966 request: proto::MarkNotificationRead,
3967 response: Response<proto::MarkNotificationRead>,
3968 session: Session,
3969) -> Result<()> {
3970 let database = &session.db().await;
3971 let notifications = database
3972 .mark_notification_as_read_by_id(
3973 session.user_id(),
3974 NotificationId::from_proto(request.notification_id),
3975 )
3976 .await?;
3977 send_notifications(
3978 &*session.connection_pool().await,
3979 &session.peer,
3980 notifications,
3981 );
3982 response.send(proto::Ack {})?;
3983 Ok(())
3984}
3985
3986/// Get the current users information
3987async fn get_private_user_info(
3988 _request: proto::GetPrivateUserInfo,
3989 response: Response<proto::GetPrivateUserInfo>,
3990 session: Session,
3991) -> Result<()> {
3992 let db = session.db().await;
3993
3994 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3995 let user = db
3996 .get_user_by_id(session.user_id())
3997 .await?
3998 .ok_or_else(|| anyhow!("user not found"))?;
3999 let flags = db.get_user_flags(session.user_id()).await?;
4000
4001 response.send(proto::GetPrivateUserInfoResponse {
4002 metrics_id,
4003 staff: user.admin,
4004 flags,
4005 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4006 })?;
4007 Ok(())
4008}
4009
4010/// Accept the terms of service (tos) on behalf of the current user
4011async fn accept_terms_of_service(
4012 _request: proto::AcceptTermsOfService,
4013 response: Response<proto::AcceptTermsOfService>,
4014 session: Session,
4015) -> Result<()> {
4016 let db = session.db().await;
4017
4018 let accepted_tos_at = Utc::now();
4019 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4020 .await?;
4021
4022 response.send(proto::AcceptTermsOfServiceResponse {
4023 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4024 })?;
4025 Ok(())
4026}
4027
4028/// The minimum account age an account must have in order to use the LLM service.
4029const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4030
4031async fn get_llm_api_token(
4032 _request: proto::GetLlmToken,
4033 response: Response<proto::GetLlmToken>,
4034 session: Session,
4035) -> Result<()> {
4036 let db = session.db().await;
4037
4038 let flags = db.get_user_flags(session.user_id()).await?;
4039 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4040 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4041 let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4042
4043 if !session.is_staff() && !has_language_models_feature_flag {
4044 Err(anyhow!("permission denied"))?
4045 }
4046
4047 let user_id = session.user_id();
4048 let user = db
4049 .get_user_by_id(user_id)
4050 .await?
4051 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4052
4053 if user.accepted_tos_at.is_none() {
4054 Err(anyhow!("terms of service not accepted"))?
4055 }
4056
4057 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4058
4059 let bypass_account_age_check =
4060 has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4061 if !bypass_account_age_check {
4062 let mut account_created_at = user.created_at;
4063 if let Some(github_created_at) = user.github_user_created_at {
4064 account_created_at = account_created_at.min(github_created_at);
4065 }
4066 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4067 Err(anyhow!("account too young"))?
4068 }
4069 }
4070
4071 let billing_preferences = db.get_billing_preferences(user.id).await?;
4072
4073 let token = LlmTokenClaims::create(
4074 &user,
4075 session.is_staff(),
4076 billing_preferences,
4077 has_llm_closed_beta_feature_flag,
4078 has_predict_edits_feature_flag,
4079 has_llm_subscription,
4080 session.current_plan(&db).await?,
4081 session.system_id.clone(),
4082 &session.app_state.config,
4083 )?;
4084 response.send(proto::GetLlmTokenResponse { token })?;
4085 Ok(())
4086}
4087
4088fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4089 let message = match message {
4090 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4091 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4092 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4093 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4094 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4095 code: frame.code.into(),
4096 reason: frame.reason,
4097 })),
4098 // We should never receive a frame while reading the message, according
4099 // to the `tungstenite` maintainers:
4100 //
4101 // > It cannot occur when you read messages from the WebSocket, but it
4102 // > can be used when you want to send the raw frames (e.g. you want to
4103 // > send the frames to the WebSocket without composing the full message first).
4104 // >
4105 // > — https://github.com/snapview/tungstenite-rs/issues/268
4106 TungsteniteMessage::Frame(_) => {
4107 bail!("received an unexpected frame while reading the message")
4108 }
4109 };
4110
4111 Ok(message)
4112}
4113
4114fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4115 match message {
4116 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4117 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4118 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4119 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4120 AxumMessage::Close(frame) => {
4121 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4122 code: frame.code.into(),
4123 reason: frame.reason,
4124 }))
4125 }
4126 }
4127}
4128
4129fn notify_membership_updated(
4130 connection_pool: &mut ConnectionPool,
4131 result: MembershipUpdated,
4132 user_id: UserId,
4133 peer: &Peer,
4134) {
4135 for membership in &result.new_channels.channel_memberships {
4136 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4137 }
4138 for channel_id in &result.removed_channels {
4139 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4140 }
4141
4142 let user_channels_update = proto::UpdateUserChannels {
4143 channel_memberships: result
4144 .new_channels
4145 .channel_memberships
4146 .iter()
4147 .map(|cm| proto::ChannelMembership {
4148 channel_id: cm.channel_id.to_proto(),
4149 role: cm.role.into(),
4150 })
4151 .collect(),
4152 ..Default::default()
4153 };
4154
4155 let mut update = build_channels_update(result.new_channels);
4156 update.delete_channels = result
4157 .removed_channels
4158 .into_iter()
4159 .map(|id| id.to_proto())
4160 .collect();
4161 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4162
4163 for connection_id in connection_pool.user_connection_ids(user_id) {
4164 peer.send(connection_id, user_channels_update.clone())
4165 .trace_err();
4166 peer.send(connection_id, update.clone()).trace_err();
4167 }
4168}
4169
4170fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4171 proto::UpdateUserChannels {
4172 channel_memberships: channels
4173 .channel_memberships
4174 .iter()
4175 .map(|m| proto::ChannelMembership {
4176 channel_id: m.channel_id.to_proto(),
4177 role: m.role.into(),
4178 })
4179 .collect(),
4180 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4181 observed_channel_message_id: channels.observed_channel_messages.clone(),
4182 }
4183}
4184
4185fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4186 let mut update = proto::UpdateChannels::default();
4187
4188 for channel in channels.channels {
4189 update.channels.push(channel.to_proto());
4190 }
4191
4192 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4193 update.latest_channel_message_ids = channels.latest_channel_messages;
4194
4195 for (channel_id, participants) in channels.channel_participants {
4196 update
4197 .channel_participants
4198 .push(proto::ChannelParticipants {
4199 channel_id: channel_id.to_proto(),
4200 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4201 });
4202 }
4203
4204 for channel in channels.invited_channels {
4205 update.channel_invitations.push(channel.to_proto());
4206 }
4207
4208 update
4209}
4210
4211fn build_initial_contacts_update(
4212 contacts: Vec<db::Contact>,
4213 pool: &ConnectionPool,
4214) -> proto::UpdateContacts {
4215 let mut update = proto::UpdateContacts::default();
4216
4217 for contact in contacts {
4218 match contact {
4219 db::Contact::Accepted { user_id, busy } => {
4220 update.contacts.push(contact_for_user(user_id, busy, pool));
4221 }
4222 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4223 db::Contact::Incoming { user_id } => {
4224 update
4225 .incoming_requests
4226 .push(proto::IncomingContactRequest {
4227 requester_id: user_id.to_proto(),
4228 })
4229 }
4230 }
4231 }
4232
4233 update
4234}
4235
4236fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4237 proto::Contact {
4238 user_id: user_id.to_proto(),
4239 online: pool.is_user_online(user_id),
4240 busy,
4241 }
4242}
4243
4244fn room_updated(room: &proto::Room, peer: &Peer) {
4245 broadcast(
4246 None,
4247 room.participants
4248 .iter()
4249 .filter_map(|participant| Some(participant.peer_id?.into())),
4250 |peer_id| {
4251 peer.send(
4252 peer_id,
4253 proto::RoomUpdated {
4254 room: Some(room.clone()),
4255 },
4256 )
4257 },
4258 );
4259}
4260
4261fn channel_updated(
4262 channel: &db::channel::Model,
4263 room: &proto::Room,
4264 peer: &Peer,
4265 pool: &ConnectionPool,
4266) {
4267 let participants = room
4268 .participants
4269 .iter()
4270 .map(|p| p.user_id)
4271 .collect::<Vec<_>>();
4272
4273 broadcast(
4274 None,
4275 pool.channel_connection_ids(channel.root_id())
4276 .filter_map(|(channel_id, role)| {
4277 role.can_see_channel(channel.visibility)
4278 .then_some(channel_id)
4279 }),
4280 |peer_id| {
4281 peer.send(
4282 peer_id,
4283 proto::UpdateChannels {
4284 channel_participants: vec![proto::ChannelParticipants {
4285 channel_id: channel.id.to_proto(),
4286 participant_user_ids: participants.clone(),
4287 }],
4288 ..Default::default()
4289 },
4290 )
4291 },
4292 );
4293}
4294
4295async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4296 let db = session.db().await;
4297
4298 let contacts = db.get_contacts(user_id).await?;
4299 let busy = db.is_user_busy(user_id).await?;
4300
4301 let pool = session.connection_pool().await;
4302 let updated_contact = contact_for_user(user_id, busy, &pool);
4303 for contact in contacts {
4304 if let db::Contact::Accepted {
4305 user_id: contact_user_id,
4306 ..
4307 } = contact
4308 {
4309 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4310 session
4311 .peer
4312 .send(
4313 contact_conn_id,
4314 proto::UpdateContacts {
4315 contacts: vec![updated_contact.clone()],
4316 remove_contacts: Default::default(),
4317 incoming_requests: Default::default(),
4318 remove_incoming_requests: Default::default(),
4319 outgoing_requests: Default::default(),
4320 remove_outgoing_requests: Default::default(),
4321 },
4322 )
4323 .trace_err();
4324 }
4325 }
4326 }
4327 Ok(())
4328}
4329
/// Removes `connection_id` from its room and notifies everyone affected.
///
/// Asks the database to leave the room for `connection_id`. If the database
/// reports that nothing was left (`None`), this returns `Ok(())` right away.
/// Otherwise it performs, in order:
///
/// 1. `project_left` for every project listed in the result,
/// 2. `room_updated` for the resulting room,
/// 3. `channel_updated` if the result carries a channel,
/// 4. a `proto::CallCanceled` to every connection of each user whose call
///    was canceled,
/// 5. `update_user_contacts` for the session's user and each of those users,
/// 6. if a LiveKit client is configured: removal of the session's user from
///    the LiveKit room, and deletion of that room when the database result
///    has `deleted` set.
///
/// # Errors
///
/// Returns an error if the database `leave_room` call fails or if any
/// `update_user_contacts` call fails. The results of sends and LiveKit calls
/// are passed to `trace_err` and do not abort the function.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // remain available after `left_room` has gone out of scope.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is taken out of `left_room.room` before the room itself
        // is taken, so `room` below holds the default value in that field.
        // NOTE(review): confirm recipients of `RoomUpdated` do not rely on it.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // Nothing was left, so there is nobody to notify.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so `pool` is dropped before `update_user_contacts` runs below;
        // that function acquires the connection pool itself.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` returns on the first failed update, which skips the
    // remaining users and the LiveKit cleanup below — confirm this is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4403
4404async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4405 let left_channel_buffers = session
4406 .db()
4407 .await
4408 .leave_channel_buffers(session.connection_id)
4409 .await?;
4410
4411 for left_buffer in left_channel_buffers {
4412 channel_buffer_updated(
4413 session.connection_id,
4414 left_buffer.connections,
4415 &proto::UpdateChannelBufferCollaborators {
4416 channel_id: left_buffer.channel_id.to_proto(),
4417 collaborators: left_buffer.collaborators,
4418 },
4419 &session.peer,
4420 );
4421 }
4422
4423 Ok(())
4424}
4425
4426fn project_left(project: &db::LeftProject, session: &Session) {
4427 for connection_id in &project.connection_ids {
4428 if project.should_unshare {
4429 session
4430 .peer
4431 .send(
4432 *connection_id,
4433 proto::UnshareProject {
4434 project_id: project.id.to_proto(),
4435 },
4436 )
4437 .trace_err();
4438 } else {
4439 session
4440 .peer
4441 .send(
4442 *connection_id,
4443 proto::RemoveProjectCollaborator {
4444 project_id: project.id.to_proto(),
4445 peer_id: Some(session.connection_id.into()),
4446 },
4447 )
4448 .trace_err();
4449 }
4450 }
4451}
4452
/// Extension trait for turning a fallible value into an `Option`.
///
/// The implementation for `Result` in this file logs the error with
/// `tracing::error!` and yields `None` in the failure case.
pub trait ResultExt {
    /// The success value returned inside `Some` by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self`, returning the success value if there is one and
    /// `None` otherwise.
    fn trace_err(self) -> Option<Self::Ok>;
}
4458
4459impl<T, E> ResultExt for Result<T, E>
4460where
4461 E: std::fmt::Debug,
4462{
4463 type Ok = T;
4464
4465 #[track_caller]
4466 fn trace_err(self) -> Option<T> {
4467 match self {
4468 Ok(value) => Some(value),
4469 Err(error) => {
4470 tracing::error!("{:?}", error);
4471 None
4472 }
4473 }
4474 }
4475}