1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use rpc::proto::split_repository_update;
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
46 TryStreamExt,
47};
48use prometheus::{register_int_gauge, IntGauge};
49use rpc::{
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 atomic::{AtomicBool, Ordering::SeqCst},
68 Arc, OnceLock,
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{watch, MutexGuard, Semaphore};
74use tower::ServiceBuilder;
75use tracing::{
76 field::{self},
77 info_span, instrument, Instrument,
78};
79
/// A 30-second timeout.
// NOTE(review): presumably the window a dropped client has to reconnect before its state is
// cleaned up — the code that consumes this constant is outside this section; confirm there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before it clears stale state.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this section; judging by their names they
// bound chat-message pages, chat-message length and notification pages — confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
/// A type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` of the connection the message
/// arrived on, and returns a boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
/// Handle passed to a request handler so it can answer a single request of type `R`.
struct Response<R> {
    /// Peer through which the response is delivered.
    peer: Arc<Peer>,
    /// Receipt of the request that is being answered.
    receipt: Receipt<R>,
    /// Flag shared with `Server::add_request_handler`; set to `true` once `send` is called so the
    /// server can detect handlers that returned `Ok` without responding.
    responded: Arc<AtomicBool>,
}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `user` is the account being acted as, while `admin` is the account performing the
    /// impersonation (it is reported as `impersonator` in tracing spans and `Debug` output).
    Impersonated { user: User, admin: User },
}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
/// Per-connection state that is cloned into every message handler invocation.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Identifier the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; obtain it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer.
    peer: Arc<Peer>,
    /// The server-wide connection pool, shared with `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, executor, ...).
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System identifier supplied when the connection was established, if any.
    system_id: Option<String>,
    /// Executor the connection was started with.
    _executor: Executor,
}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
227struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer that all connections are attached to.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Publishes `true` when the server is torn down; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
245
/// Wrapper around the locked connection pool.
///
/// It dereferences to [`ConnectionPool`] and, in test builds, checks the pool's invariants when
/// it is dropped.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` lock guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker keeps the guard from being sent to another thread.
    _not_send: PhantomData<Rc<()>>,
}
250
/// A serializable view of the server's peer and connection pool.
///
/// The pool stays locked for as long as the snapshot is alive, because the snapshot holds the
/// lock guard.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to instead.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
257
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
/// Creates a server with the given id and registers a handler for every RPC message type it
/// understands.
///
/// The peer is created from the numeric server id, the connection pool and handler table start
/// out empty, and the teardown flag starts as `false`.
///
/// # Panics
///
/// Panics if two handlers are registered for the same message type (enforced by
/// `add_handler`, which every registration below goes through).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        teardown: watch::channel(false).0,
    };

    // `add_request_handler` registers handlers that must send a response;
    // `add_message_handler` registers handlers for messages without one.
    server
        .add_request_handler(ping)
        .add_request_handler(create_room)
        .add_request_handler(join_room)
        .add_request_handler(rejoin_room)
        .add_request_handler(leave_room)
        .add_request_handler(set_room_participant_role)
        .add_request_handler(call)
        .add_request_handler(cancel_call)
        .add_message_handler(decline_call)
        .add_request_handler(update_participant_location)
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(join_project)
        .add_message_handler(leave_project)
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_request_handler(update_repository)
        .add_request_handler(remove_repository)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_find_search_candidates_request)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
        .add_request_handler(
            forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
        .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
        .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
        .add_request_handler(get_users)
        .add_request_handler(fuzzy_search_users)
        .add_request_handler(request_contact)
        .add_request_handler(remove_contact)
        .add_request_handler(respond_to_contact_request)
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(create_channel)
        .add_request_handler(delete_channel)
        .add_request_handler(invite_channel_member)
        .add_request_handler(remove_channel_member)
        .add_request_handler(set_channel_member_role)
        .add_request_handler(set_channel_visibility)
        .add_request_handler(rename_channel)
        .add_request_handler(join_channel_buffer)
        .add_request_handler(leave_channel_buffer)
        .add_message_handler(update_channel_buffer)
        .add_request_handler(rejoin_channel_buffers)
        .add_request_handler(get_channel_members)
        .add_request_handler(respond_to_channel_invite)
        .add_request_handler(join_channel)
        .add_request_handler(join_channel_chat)
        .add_message_handler(leave_channel_chat)
        .add_request_handler(send_channel_message)
        .add_request_handler(remove_channel_message)
        .add_request_handler(update_channel_message)
        .add_request_handler(get_channel_messages)
        .add_request_handler(get_channel_messages_by_id)
        .add_request_handler(get_notifications)
        .add_request_handler(mark_notification_as_read)
        .add_request_handler(move_channel)
        .add_request_handler(follow)
        .add_message_handler(unfollow)
        .add_message_handler(update_followers)
        .add_request_handler(get_private_user_info)
        .add_request_handler(get_llm_api_token)
        .add_request_handler(accept_terms_of_service)
        .add_message_handler(acknowledge_channel_message)
        .add_message_handler(acknowledge_buffer_version)
        .add_request_handler(get_supermaven_api_key)
        .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
        .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
        .add_request_handler(forward_mutating_project_request::<proto::Stage>)
        .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
        .add_request_handler(forward_mutating_project_request::<proto::Commit>)
        .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
        .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
        .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
        // NOTE(review): `GitReset` and `GitCheckoutFiles` are forwarded through the read-only
        // path even though their names suggest they modify the working tree — confirm intended.
        .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
        .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
        .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
        .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
        .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
        .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
        .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // The remaining handlers need the configuration, so each closure captures its own
        // clone of `app_state`.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler(get_cached_embeddings)
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            }
        });

    Arc::new(server)
}
445
/// Schedules the cleanup of resources left behind by other (stale) servers and returns
/// immediately with `Ok(())`.
///
/// The spawned, detached task:
/// 1. waits for [`CLEANUP_TIMEOUT`];
/// 2. asks the database for stale room and channel-buffer ids;
/// 3. clears stale channel-buffer collaborators and notifies the remaining connections;
/// 4. clears stale room participants, notifies rooms, channels, users whose calls were
///    canceled and the contacts of affected users, and deletes the LiveKit room when no
///    participants remain;
/// 5. deletes the stale server records.
///
/// Every failure inside the task is reported through `trace_err` and otherwise skipped, so one
/// failing step does not abort the rest of the cleanup.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // The sleep future is created here and only awaited inside the spawned task.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        // Tell every remaining connection about the refreshed collaborators.
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // The pool lock is confined to this block so it is released before the
                    // database calls awaited below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        // Only notify when both lookups succeeded.
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when the stale ids could not be retrieved above.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
591
/// Shuts the server down: tears down the peer, resets the connection pool, and publishes
/// `true` on the teardown channel that connection tasks subscribe to.
pub fn teardown(&self) {
    self.peer.teardown();
    self.connection_pool.lock().reset();
    // A failed send is deliberately ignored.
    let _ = self.teardown.send(true);
}
597
/// Test-only: tears the server down, then gives it a new `id` (resetting the peer with the
/// same numeric id) and publishes `false` on the teardown channel.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // A failed send is deliberately ignored.
    let _ = self.teardown.send(false);
}
605
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    *self.id.lock()
}
610
611 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
612 where
613 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
614 Fut: 'static + Send + Future<Output = Result<()>>,
615 M: EnvelopedMessage,
616 {
617 let prev_handler = self.handlers.insert(
618 TypeId::of::<M>(),
619 Box::new(move |envelope, session| {
620 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
621 let received_at = envelope.received_at;
622 tracing::info!("message received");
623 let start_time = Instant::now();
624 let future = (handler)(*envelope, session);
625 async move {
626 let result = future.await;
627 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
628 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
629 let queue_duration_ms = total_duration_ms - processing_duration_ms;
630 let payload_type = M::NAME;
631
632 match result {
633 Err(error) => {
634 tracing::error!(
635 ?error,
636 total_duration_ms,
637 processing_duration_ms,
638 queue_duration_ms,
639 payload_type,
640 "error handling message"
641 )
642 }
643 Ok(()) => tracing::info!(
644 total_duration_ms,
645 processing_duration_ms,
646 queue_duration_ms,
647 "finished handling message"
648 ),
649 }
650 }
651 .boxed()
652 }),
653 );
654 if prev_handler.is_some() {
655 panic!("registered a handler for the same message twice");
656 }
657 self
658 }
659
660 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
661 where
662 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
663 Fut: 'static + Send + Future<Output = Result<()>>,
664 M: EnvelopedMessage,
665 {
666 self.add_handler(move |envelope, session| handler(envelope.payload, session));
667 self
668 }
669
670 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
671 where
672 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
673 Fut: Send + Future<Output = Result<()>>,
674 M: RequestMessage,
675 {
676 let handler = Arc::new(handler);
677 self.add_handler(move |envelope, session| {
678 let receipt = envelope.receipt();
679 let handler = handler.clone();
680 async move {
681 let peer = session.peer.clone();
682 let responded = Arc::new(AtomicBool::default());
683 let response = Response {
684 peer: peer.clone(),
685 responded: responded.clone(),
686 receipt,
687 };
688 match (handler)(envelope.payload, response, session).await {
689 Ok(()) => {
690 if responded.load(std::sync::atomic::Ordering::SeqCst) {
691 Ok(())
692 } else {
693 Err(anyhow!("handler did not send a response"))?
694 }
695 }
696 Err(error) => {
697 let proto_err = match &error {
698 Error::Internal(err) => err.to_proto(),
699 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
700 };
701 peer.respond_with_error(receipt, proto_err)?;
702 Err(error)
703 }
704 }
705 }
706 })
707 }
708
/// Returns a future that drives a single client connection from start to finish.
///
/// The future:
/// 1. bails out immediately if the server is already tearing down;
/// 2. attaches `connection` to the peer and builds the per-connection [`Session`]
///    (including an HTTP client and, when an admin API key is configured, a Supermaven
///    client);
/// 3. sends the initial client update (see `send_initial_client_update`);
/// 4. loops, dispatching incoming messages to their registered handlers until the
///    connection closes, the I/O future finishes, or the server is torn down;
/// 5. signs the connection out through `connection_lost` — except when the loop exits
///    because of a teardown, in which case the future returns right away.
///
/// At most 256 handlers run concurrently per connection: a semaphore permit is acquired
/// before the next message is read and released when its handler completes.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    // Fields declared as `Empty` are filled in later, once their values are known.
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            http_client,
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            // A permit is acquired *before* reading the next message, so no further message is
            // taken off the connection while all permits are in use.
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls the branches in the order written: teardown first, then
            // I/O, then running handlers, and only then the next incoming message.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            // Leave the span before moving it into the instrumented future.
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                // Release the concurrency permit once the handler is done.
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Foreground handlers that are still pending are dropped here, before signing out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
857
858 async fn send_initial_client_update(
859 &self,
860 connection_id: ConnectionId,
861 principal: &Principal,
862 zed_version: ZedVersion,
863 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
864 session: &Session,
865 ) -> Result<()> {
866 self.peer.send(
867 connection_id,
868 proto::Hello {
869 peer_id: Some(connection_id.into()),
870 },
871 )?;
872 tracing::info!("sent hello message");
873 if let Some(send_connection_id) = send_connection_id.take() {
874 let _ = send_connection_id.send(connection_id);
875 }
876
877 match principal {
878 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
879 if !user.connected_once {
880 self.peer.send(connection_id, proto::ShowContacts {})?;
881 self.app_state
882 .db
883 .set_user_connected_once(user.id, true)
884 .await?;
885 }
886
887 update_user_plan(user.id, session).await?;
888
889 let contacts = self.app_state.db.get_contacts(user.id).await?;
890
891 {
892 let mut pool = self.connection_pool.lock();
893 pool.add_connection(connection_id, user.id, user.admin, zed_version);
894 self.peer.send(
895 connection_id,
896 build_initial_contacts_update(contacts, &pool),
897 )?;
898 }
899
900 if should_auto_subscribe_to_channels(zed_version) {
901 subscribe_user_to_channels(user.id, session).await?;
902 }
903
904 if let Some(incoming_call) =
905 self.app_state.db.incoming_call_for_user(user.id).await?
906 {
907 self.peer.send(connection_id, incoming_call)?;
908 }
909
910 update_user_contacts(user.id, session).await?;
911 }
912 }
913
914 Ok(())
915 }
916
917 pub async fn invite_code_redeemed(
918 self: &Arc<Self>,
919 inviter_id: UserId,
920 invitee_id: UserId,
921 ) -> Result<()> {
922 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
923 if let Some(code) = &user.invite_code {
924 let pool = self.connection_pool.lock();
925 let invitee_contact = contact_for_user(invitee_id, false, &pool);
926 for connection_id in pool.user_connection_ids(inviter_id) {
927 self.peer.send(
928 connection_id,
929 proto::UpdateContacts {
930 contacts: vec![invitee_contact.clone()],
931 ..Default::default()
932 },
933 )?;
934 self.peer.send(
935 connection_id,
936 proto::UpdateInviteInfo {
937 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
938 count: user.invite_count as u32,
939 },
940 )?;
941 }
942 }
943 }
944 Ok(())
945 }
946
947 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
948 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
949 if let Some(invite_code) = &user.invite_code {
950 let pool = self.connection_pool.lock();
951 for connection_id in pool.user_connection_ids(user_id) {
952 self.peer.send(
953 connection_id,
954 proto::UpdateInviteInfo {
955 url: format!(
956 "{}{}",
957 self.app_state.config.invite_link_prefix, invite_code
958 ),
959 count: user.invite_count as u32,
960 },
961 )?;
962 }
963 }
964 }
965 Ok(())
966 }
967
968 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
969 let pool = self.connection_pool.lock();
970 for connection_id in pool.user_connection_ids(user_id) {
971 self.peer
972 .send(connection_id, proto::RefreshLlmToken {})
973 .trace_err();
974 }
975 }
976
977 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
978 ServerSnapshot {
979 connection_pool: ConnectionPoolGuard {
980 guard: self.connection_pool.lock(),
981 _not_send: PhantomData,
982 },
983 peer: &self.peer,
984 }
985 }
986}
987
988impl Deref for ConnectionPoolGuard<'_> {
989 type Target = ConnectionPool;
990
991 fn deref(&self) -> &Self::Target {
992 &self.guard
993 }
994}
995
996impl DerefMut for ConnectionPoolGuard<'_> {
997 fn deref_mut(&mut self) -> &mut Self::Target {
998 &mut self.guard
999 }
1000}
1001
1002impl Drop for ConnectionPoolGuard<'_> {
1003 fn drop(&mut self) {
1004 #[cfg(test)]
1005 self.check_invariants();
1006 }
1007}
1008
1009fn broadcast<F>(
1010 sender_id: Option<ConnectionId>,
1011 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1012 mut f: F,
1013) where
1014 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1015{
1016 for receiver_id in receiver_ids {
1017 if Some(receiver_id) != sender_id {
1018 if let Err(error) = f(receiver_id) {
1019 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1020 }
1021 }
1022 }
1023}
1024
1025pub struct ProtocolVersion(u32);
1026
1027impl Header for ProtocolVersion {
1028 fn name() -> &'static HeaderName {
1029 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1030 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1031 }
1032
1033 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1034 where
1035 Self: Sized,
1036 I: Iterator<Item = &'i axum::http::HeaderValue>,
1037 {
1038 let version = values
1039 .next()
1040 .ok_or_else(axum::headers::Error::invalid)?
1041 .to_str()
1042 .map_err(|_| axum::headers::Error::invalid())?
1043 .parse()
1044 .map_err(|_| axum::headers::Error::invalid())?;
1045 Ok(Self(version))
1046 }
1047
1048 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1049 values.extend([self.0.to_string().parse().unwrap()]);
1050 }
1051}
1052
1053pub struct AppVersionHeader(SemanticVersion);
1054impl Header for AppVersionHeader {
1055 fn name() -> &'static HeaderName {
1056 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1057 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1058 }
1059
1060 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1061 where
1062 Self: Sized,
1063 I: Iterator<Item = &'i axum::http::HeaderValue>,
1064 {
1065 let version = values
1066 .next()
1067 .ok_or_else(axum::headers::Error::invalid)?
1068 .to_str()
1069 .map_err(|_| axum::headers::Error::invalid())?
1070 .parse()
1071 .map_err(|_| axum::headers::Error::invalid())?;
1072 Ok(Self(version))
1073 }
1074
1075 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1076 values.extend([self.0.to_string().parse().unwrap()]);
1077 }
1078}
1079
1080pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1081 Router::new()
1082 .route("/rpc", get(handle_websocket_request))
1083 .layer(
1084 ServiceBuilder::new()
1085 .layer(Extension(server.app_state.clone()))
1086 .layer(middleware::from_fn(auth::validate_header)),
1087 )
1088 .route("/metrics", get(handle_metrics))
1089 .layer(Extension(server))
1090}
1091
1092pub async fn handle_websocket_request(
1093 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1094 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1095 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1096 Extension(server): Extension<Arc<Server>>,
1097 Extension(principal): Extension<Principal>,
1098 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1099 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1100 ws: WebSocketUpgrade,
1101) -> axum::response::Response {
1102 if protocol_version != rpc::PROTOCOL_VERSION {
1103 return (
1104 StatusCode::UPGRADE_REQUIRED,
1105 "client must be upgraded".to_string(),
1106 )
1107 .into_response();
1108 }
1109
1110 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "no version header found".to_string(),
1114 )
1115 .into_response();
1116 };
1117
1118 if !version.can_collaborate() {
1119 return (
1120 StatusCode::UPGRADE_REQUIRED,
1121 "client must be upgraded".to_string(),
1122 )
1123 .into_response();
1124 }
1125
1126 let socket_address = socket_address.to_string();
1127 ws.on_upgrade(move |socket| {
1128 let socket = socket
1129 .map_ok(to_tungstenite_message)
1130 .err_into()
1131 .with(|message| async move { to_axum_message(message) });
1132 let connection = Connection::new(Box::pin(socket));
1133 async move {
1134 server
1135 .handle_connection(
1136 connection,
1137 socket_address,
1138 principal,
1139 version,
1140 country_code_header.map(|header| header.to_string()),
1141 system_id_header.map(|header| header.to_string()),
1142 None,
1143 Executor::Production,
1144 )
1145 .await;
1146 }
1147 })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152 let connections_metric = CONNECTIONS_METRIC
1153 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155 let connections = server
1156 .connection_pool
1157 .lock()
1158 .connections()
1159 .filter(|connection| !connection.admin)
1160 .count();
1161 connections_metric.set(connections as _);
1162
1163 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165 register_int_gauge!(
1166 "shared_projects",
1167 "number of open projects with one or more guests"
1168 )
1169 .unwrap()
1170 });
1171
1172 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173 shared_projects_metric.set(shared_projects as _);
1174
1175 let encoder = prometheus::TextEncoder::new();
1176 let metric_families = prometheus::gather();
1177 let encoded_metrics = encoder
1178 .encode_to_string(&metric_families)
1179 .map_err(|err| anyhow!("{}", err))?;
1180 Ok(encoded_metrics)
1181}
1182
/// Cleans up after a connection has gone away.
///
/// The peer is disconnected and the connection is removed from the pool right away (a
/// failure to remove it aborts the function), and the loss is recorded in the database
/// (a failure there is handed to `trace_err`). The function then waits for whichever
/// happens first:
///
/// * `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers, any call
///   to the user is declined if the user has no remaining connection in the pool, and the
///   user's contacts are updated.
/// * `teardown` changes: nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout branch is
    // checked before the teardown branch when both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Both cleanup steps are best-effort: their errors go to `trace_err`.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when none of the user's connections remain.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1229
1230/// Acknowledges a ping from a client, used to keep the connection alive.
1231async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1232 response.send(proto::Ack {})?;
1233 Ok(())
1234}
1235
1236/// Creates a new room for calling (outside of channels)
1237async fn create_room(
1238 _request: proto::CreateRoom,
1239 response: Response<proto::CreateRoom>,
1240 session: Session,
1241) -> Result<()> {
1242 let livekit_room = nanoid::nanoid!(30);
1243
1244 let live_kit_connection_info = util::maybe!(async {
1245 let live_kit = session.app_state.livekit_client.as_ref();
1246 let live_kit = live_kit?;
1247 let user_id = session.user_id().to_string();
1248
1249 let token = live_kit
1250 .room_token(&livekit_room, &user_id.to_string())
1251 .trace_err()?;
1252
1253 Some(proto::LiveKitConnectionInfo {
1254 server_url: live_kit.url().into(),
1255 token,
1256 can_publish: true,
1257 })
1258 })
1259 .await;
1260
1261 let room = session
1262 .db()
1263 .await
1264 .create_room(session.user_id(), session.connection_id, &livekit_room)
1265 .await?;
1266
1267 response.send(proto::CreateRoomResponse {
1268 room: Some(room.clone()),
1269 live_kit_connection_info,
1270 })?;
1271
1272 update_user_contacts(session.user_id(), &session).await?;
1273 Ok(())
1274}
1275
1276/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1277async fn join_room(
1278 request: proto::JoinRoom,
1279 response: Response<proto::JoinRoom>,
1280 session: Session,
1281) -> Result<()> {
1282 let room_id = RoomId::from_proto(request.id);
1283
1284 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1285
1286 if let Some(channel_id) = channel_id {
1287 return join_channel_internal(channel_id, Box::new(response), session).await;
1288 }
1289
1290 let joined_room = {
1291 let room = session
1292 .db()
1293 .await
1294 .join_room(room_id, session.user_id(), session.connection_id)
1295 .await?;
1296 room_updated(&room.room, &session.peer);
1297 room.into_inner()
1298 };
1299
1300 for connection_id in session
1301 .connection_pool()
1302 .await
1303 .user_connection_ids(session.user_id())
1304 {
1305 session
1306 .peer
1307 .send(
1308 connection_id,
1309 proto::CallCanceled {
1310 room_id: room_id.to_proto(),
1311 },
1312 )
1313 .trace_err();
1314 }
1315
1316 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1317 {
1318 live_kit
1319 .room_token(
1320 &joined_room.room.livekit_room,
1321 &session.user_id().to_string(),
1322 )
1323 .trace_err()
1324 .map(|token| proto::LiveKitConnectionInfo {
1325 server_url: live_kit.url().into(),
1326 token,
1327 can_publish: true,
1328 })
1329 } else {
1330 None
1331 };
1332
1333 response.send(proto::JoinRoomResponse {
1334 room: Some(joined_room.room),
1335 channel_id: None,
1336 live_kit_connection_info,
1337 })?;
1338
1339 update_user_contacts(session.user_id(), &session).await?;
1340 Ok(())
1341}
1342
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The caller receives the room together with its reshared and rejoined projects. The
/// collaborators of each reshared project are told about the new peer id and receive the
/// project's worktree metadata; rejoined projects are handled by
/// `notify_rejoined_projects`. When the room has a channel, `channel_updated` is called
/// for it, and finally the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block below so they outlive the value returned by the database.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old peer id now maps to this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree metadata to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1434
1435fn notify_rejoined_projects(
1436 rejoined_projects: &mut Vec<RejoinedProject>,
1437 session: &Session,
1438) -> Result<()> {
1439 for project in rejoined_projects.iter() {
1440 for collaborator in &project.collaborators {
1441 session
1442 .peer
1443 .send(
1444 collaborator.connection_id,
1445 proto::UpdateProjectCollaborator {
1446 project_id: project.id.to_proto(),
1447 old_peer_id: Some(project.old_connection_id.into()),
1448 new_peer_id: Some(session.connection_id.into()),
1449 },
1450 )
1451 .trace_err();
1452 }
1453 }
1454
1455 for project in rejoined_projects {
1456 for worktree in mem::take(&mut project.worktrees) {
1457 // Stream this worktree's entries.
1458 let message = proto::UpdateWorktree {
1459 project_id: project.id.to_proto(),
1460 worktree_id: worktree.id,
1461 abs_path: worktree.abs_path.clone(),
1462 root_name: worktree.root_name,
1463 updated_entries: worktree.updated_entries,
1464 removed_entries: worktree.removed_entries,
1465 scan_id: worktree.scan_id,
1466 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1467 updated_repositories: worktree.updated_repositories,
1468 removed_repositories: worktree.removed_repositories,
1469 };
1470 for update in proto::split_worktree_update(message) {
1471 session.peer.send(session.connection_id, update)?;
1472 }
1473
1474 // Stream this worktree's diagnostics.
1475 for summary in worktree.diagnostic_summaries {
1476 session.peer.send(
1477 session.connection_id,
1478 proto::UpdateDiagnosticSummary {
1479 project_id: project.id.to_proto(),
1480 worktree_id: worktree.id,
1481 summary: Some(summary),
1482 },
1483 )?;
1484 }
1485
1486 for settings_file in worktree.settings_files {
1487 session.peer.send(
1488 session.connection_id,
1489 proto::UpdateWorktreeSettings {
1490 project_id: project.id.to_proto(),
1491 worktree_id: worktree.id,
1492 path: settings_file.path,
1493 content: Some(settings_file.content),
1494 kind: Some(settings_file.kind.to_proto().into()),
1495 },
1496 )?;
1497 }
1498 }
1499
1500 for repository in mem::take(&mut project.updated_repositories) {
1501 for update in split_repository_update(repository) {
1502 session.peer.send(session.connection_id, update)?;
1503 }
1504 }
1505
1506 for id in mem::take(&mut project.removed_repositories) {
1507 session.peer.send(
1508 session.connection_id,
1509 proto::RemoveRepository {
1510 project_id: project.id.to_proto(),
1511 id,
1512 },
1513 )?;
1514 }
1515 }
1516
1517 Ok(())
1518}
1519
1520/// leave room disconnects from the room.
1521async fn leave_room(
1522 _: proto::LeaveRoom,
1523 response: Response<proto::LeaveRoom>,
1524 session: Session,
1525) -> Result<()> {
1526 leave_room_for_session(&session, session.connection_id).await?;
1527 response.send(proto::Ack {})?;
1528 Ok(())
1529}
1530
1531/// Updates the permissions of someone else in the room.
1532async fn set_room_participant_role(
1533 request: proto::SetRoomParticipantRole,
1534 response: Response<proto::SetRoomParticipantRole>,
1535 session: Session,
1536) -> Result<()> {
1537 let user_id = UserId::from_proto(request.user_id);
1538 let role = ChannelRole::from(request.role());
1539
1540 let (livekit_room, can_publish) = {
1541 let room = session
1542 .db()
1543 .await
1544 .set_room_participant_role(
1545 session.user_id(),
1546 RoomId::from_proto(request.room_id),
1547 user_id,
1548 role,
1549 )
1550 .await?;
1551
1552 let livekit_room = room.livekit_room.clone();
1553 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1554 room_updated(&room, &session.peer);
1555 (livekit_room, can_publish)
1556 };
1557
1558 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1559 live_kit
1560 .update_participant(
1561 livekit_room.clone(),
1562 request.user_id.to_string(),
1563 livekit_api::proto::ParticipantPermission {
1564 can_subscribe: true,
1565 can_publish,
1566 can_publish_data: can_publish,
1567 hidden: false,
1568 recorder: false,
1569 },
1570 )
1571 .await
1572 .trace_err();
1573 }
1574
1575 response.send(proto::Ack {})?;
1576 Ok(())
1577}
1578
/// Call someone else into the current room
///
/// The called user must be a contact of the caller, otherwise the request fails. The call
/// is recorded in the database, the room is re-broadcast, and the incoming-call message is
/// sent as a request to every connection of the called user. The first connection that
/// answers successfully causes an `Ack` to be returned to the caller. If no connection
/// answers successfully, the call is marked as failed and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out, leaving a default value behind in the database result.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per connection of the called user, polled in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully (or the user had none): record the failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1647
1648/// Cancel an outgoing call.
1649async fn cancel_call(
1650 request: proto::CancelCall,
1651 response: Response<proto::CancelCall>,
1652 session: Session,
1653) -> Result<()> {
1654 let called_user_id = UserId::from_proto(request.called_user_id);
1655 let room_id = RoomId::from_proto(request.room_id);
1656 {
1657 let room = session
1658 .db()
1659 .await
1660 .cancel_call(room_id, session.connection_id, called_user_id)
1661 .await?;
1662 room_updated(&room, &session.peer);
1663 }
1664
1665 for connection_id in session
1666 .connection_pool()
1667 .await
1668 .user_connection_ids(called_user_id)
1669 {
1670 session
1671 .peer
1672 .send(
1673 connection_id,
1674 proto::CallCanceled {
1675 room_id: room_id.to_proto(),
1676 },
1677 )
1678 .trace_err();
1679 }
1680 response.send(proto::Ack {})?;
1681
1682 update_user_contacts(called_user_id, &session).await?;
1683 Ok(())
1684}
1685
1686/// Decline an incoming call.
1687async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1688 let room_id = RoomId::from_proto(message.room_id);
1689 {
1690 let room = session
1691 .db()
1692 .await
1693 .decline_call(Some(room_id), session.user_id())
1694 .await?
1695 .ok_or_else(|| anyhow!("failed to decline call"))?;
1696 room_updated(&room, &session.peer);
1697 }
1698
1699 for connection_id in session
1700 .connection_pool()
1701 .await
1702 .user_connection_ids(session.user_id())
1703 {
1704 session
1705 .peer
1706 .send(
1707 connection_id,
1708 proto::CallCanceled {
1709 room_id: room_id.to_proto(),
1710 },
1711 )
1712 .trace_err();
1713 }
1714 update_user_contacts(session.user_id(), &session).await?;
1715 Ok(())
1716}
1717
1718/// Updates other participants in the room with your current location.
1719async fn update_participant_location(
1720 request: proto::UpdateParticipantLocation,
1721 response: Response<proto::UpdateParticipantLocation>,
1722 session: Session,
1723) -> Result<()> {
1724 let room_id = RoomId::from_proto(request.room_id);
1725 let location = request
1726 .location
1727 .ok_or_else(|| anyhow!("invalid location"))?;
1728
1729 let db = session.db().await;
1730 let room = db
1731 .update_room_participant_location(room_id, session.connection_id, location)
1732 .await?;
1733
1734 room_updated(&room, &session.peer);
1735 response.send(proto::Ack {})?;
1736 Ok(())
1737}
1738
/// Share a project into the room.
///
/// The project is registered in the database for this connection, its id is returned to
/// the caller, and the room is then re-broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    // The value returned by the database stays alive until the end of the function because
    // `project_id` and `room` borrow from it.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1762
1763/// Unshare a project from the room.
1764async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1765 let project_id = ProjectId::from_proto(message.project_id);
1766 unshare_project_internal(project_id, session.connection_id, &session).await
1767}
1768
/// Unshares `project_id` on behalf of `connection_id`.
///
/// Guests are sent `UnshareProject` (the unsharing connection itself is skipped), the room
/// is re-broadcast when the database returns one, and the project is deleted afterwards if
/// the database asked for deletion.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the value returned by `unshare_project` is released before the project is
    // deleted below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1806
/// Join someone else's shared project.
///
/// Registers the session's connection and user as a participant of the project in the
/// database, then hands the project and replica id to `join_project_internal`, which
/// answers the request.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // The database handle is dropped explicitly before the join is completed; `project` and
    // `replica_id` borrow from the value returned by `join_project`, not from `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1825
/// Abstraction over how `join_project_internal` delivers its `JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Written with an explicit path so it names `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1834
1835fn join_project_internal(
1836 response: impl JoinProjectInternalResponse,
1837 session: Session,
1838 project: &mut Project,
1839 replica_id: &ReplicaId,
1840) -> Result<()> {
1841 let collaborators = project
1842 .collaborators
1843 .iter()
1844 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1845 .map(|collaborator| collaborator.to_proto())
1846 .collect::<Vec<_>>();
1847 let project_id = project.id;
1848 let guest_user_id = session.user_id();
1849
1850 let worktrees = project
1851 .worktrees
1852 .iter()
1853 .map(|(id, worktree)| proto::WorktreeMetadata {
1854 id: *id,
1855 root_name: worktree.root_name.clone(),
1856 visible: worktree.visible,
1857 abs_path: worktree.abs_path.clone(),
1858 })
1859 .collect::<Vec<_>>();
1860
1861 let add_project_collaborator = proto::AddProjectCollaborator {
1862 project_id: project_id.to_proto(),
1863 collaborator: Some(proto::Collaborator {
1864 peer_id: Some(session.connection_id.into()),
1865 replica_id: replica_id.0 as u32,
1866 user_id: guest_user_id.to_proto(),
1867 is_host: false,
1868 }),
1869 };
1870
1871 for collaborator in &collaborators {
1872 session
1873 .peer
1874 .send(
1875 collaborator.peer_id.unwrap().into(),
1876 add_project_collaborator.clone(),
1877 )
1878 .trace_err();
1879 }
1880
1881 // First, we send the metadata associated with each worktree.
1882 response.send(proto::JoinProjectResponse {
1883 project_id: project.id.0 as u64,
1884 worktrees: worktrees.clone(),
1885 replica_id: replica_id.0 as u32,
1886 collaborators: collaborators.clone(),
1887 language_servers: project.language_servers.clone(),
1888 role: project.role.into(),
1889 })?;
1890
1891 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1892 // Stream this worktree's entries.
1893 let message = proto::UpdateWorktree {
1894 project_id: project_id.to_proto(),
1895 worktree_id,
1896 abs_path: worktree.abs_path.clone(),
1897 root_name: worktree.root_name,
1898 updated_entries: worktree.entries,
1899 removed_entries: Default::default(),
1900 scan_id: worktree.scan_id,
1901 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1902 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1903 removed_repositories: Default::default(),
1904 };
1905 for update in proto::split_worktree_update(message) {
1906 session.peer.send(session.connection_id, update.clone())?;
1907 }
1908
1909 // Stream this worktree's diagnostics.
1910 for summary in worktree.diagnostic_summaries {
1911 session.peer.send(
1912 session.connection_id,
1913 proto::UpdateDiagnosticSummary {
1914 project_id: project_id.to_proto(),
1915 worktree_id: worktree.id,
1916 summary: Some(summary),
1917 },
1918 )?;
1919 }
1920
1921 for settings_file in worktree.settings_files {
1922 session.peer.send(
1923 session.connection_id,
1924 proto::UpdateWorktreeSettings {
1925 project_id: project_id.to_proto(),
1926 worktree_id: worktree.id,
1927 path: settings_file.path,
1928 content: Some(settings_file.content),
1929 kind: Some(settings_file.kind.to_proto() as i32),
1930 },
1931 )?;
1932 }
1933 }
1934
1935 for repository in mem::take(&mut project.repositories) {
1936 for update in split_repository_update(repository) {
1937 session.peer.send(session.connection_id, update)?;
1938 }
1939 }
1940
1941 for language_server in &project.language_servers {
1942 session.peer.send(
1943 session.connection_id,
1944 proto::UpdateLanguageServer {
1945 project_id: project_id.to_proto(),
1946 language_server_id: language_server.id,
1947 variant: Some(
1948 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1949 proto::LspDiskBasedDiagnosticsUpdated {},
1950 ),
1951 ),
1952 },
1953 )?;
1954 }
1955
1956 Ok(())
1957}
1958
/// Leave someone else's shared project.
///
/// Removes the session's connection from the project in the database, passes the result to
/// `project_left`, and re-broadcasts the room when the database returns one.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    // Held until the end of the function.
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
1978
1979/// Updates other participants with changes to the project
1980async fn update_project(
1981 request: proto::UpdateProject,
1982 response: Response<proto::UpdateProject>,
1983 session: Session,
1984) -> Result<()> {
1985 let project_id = ProjectId::from_proto(request.project_id);
1986 let (room, guest_connection_ids) = &*session
1987 .db()
1988 .await
1989 .update_project(project_id, session.connection_id, &request.worktrees)
1990 .await?;
1991 broadcast(
1992 Some(session.connection_id),
1993 guest_connection_ids.iter().copied(),
1994 |connection_id| {
1995 session
1996 .peer
1997 .forward_send(session.connection_id, connection_id, request.clone())
1998 },
1999 );
2000 if let Some(room) = room {
2001 room_updated(room, &session.peer);
2002 }
2003 response.send(proto::Ack {})?;
2004
2005 Ok(())
2006}
2007
2008/// Updates other participants with changes to the worktree
2009async fn update_worktree(
2010 request: proto::UpdateWorktree,
2011 response: Response<proto::UpdateWorktree>,
2012 session: Session,
2013) -> Result<()> {
2014 let guest_connection_ids = session
2015 .db()
2016 .await
2017 .update_worktree(&request, session.connection_id)
2018 .await?;
2019
2020 broadcast(
2021 Some(session.connection_id),
2022 guest_connection_ids.iter().copied(),
2023 |connection_id| {
2024 session
2025 .peer
2026 .forward_send(session.connection_id, connection_id, request.clone())
2027 },
2028 );
2029 response.send(proto::Ack {})?;
2030 Ok(())
2031}
2032
2033async fn update_repository(
2034 request: proto::UpdateRepository,
2035 response: Response<proto::UpdateRepository>,
2036 session: Session,
2037) -> Result<()> {
2038 let guest_connection_ids = session
2039 .db()
2040 .await
2041 .update_repository(&request, session.connection_id)
2042 .await?;
2043
2044 broadcast(
2045 Some(session.connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |connection_id| {
2048 session
2049 .peer
2050 .forward_send(session.connection_id, connection_id, request.clone())
2051 },
2052 );
2053 response.send(proto::Ack {})?;
2054 Ok(())
2055}
2056
2057async fn remove_repository(
2058 request: proto::RemoveRepository,
2059 response: Response<proto::RemoveRepository>,
2060 session: Session,
2061) -> Result<()> {
2062 let guest_connection_ids = session
2063 .db()
2064 .await
2065 .remove_repository(&request, session.connection_id)
2066 .await?;
2067
2068 broadcast(
2069 Some(session.connection_id),
2070 guest_connection_ids.iter().copied(),
2071 |connection_id| {
2072 session
2073 .peer
2074 .forward_send(session.connection_id, connection_id, request.clone())
2075 },
2076 );
2077 response.send(proto::Ack {})?;
2078 Ok(())
2079}
2080
2081/// Updates other participants with changes to the diagnostics
2082async fn update_diagnostic_summary(
2083 message: proto::UpdateDiagnosticSummary,
2084 session: Session,
2085) -> Result<()> {
2086 let guest_connection_ids = session
2087 .db()
2088 .await
2089 .update_diagnostic_summary(&message, session.connection_id)
2090 .await?;
2091
2092 broadcast(
2093 Some(session.connection_id),
2094 guest_connection_ids.iter().copied(),
2095 |connection_id| {
2096 session
2097 .peer
2098 .forward_send(session.connection_id, connection_id, message.clone())
2099 },
2100 );
2101
2102 Ok(())
2103}
2104
2105/// Updates other participants with changes to the worktree settings
2106async fn update_worktree_settings(
2107 message: proto::UpdateWorktreeSettings,
2108 session: Session,
2109) -> Result<()> {
2110 let guest_connection_ids = session
2111 .db()
2112 .await
2113 .update_worktree_settings(&message, session.connection_id)
2114 .await?;
2115
2116 broadcast(
2117 Some(session.connection_id),
2118 guest_connection_ids.iter().copied(),
2119 |connection_id| {
2120 session
2121 .peer
2122 .forward_send(session.connection_id, connection_id, message.clone())
2123 },
2124 );
2125
2126 Ok(())
2127}
2128
2129/// Notify other participants that a language server has started.
2130async fn start_language_server(
2131 request: proto::StartLanguageServer,
2132 session: Session,
2133) -> Result<()> {
2134 let guest_connection_ids = session
2135 .db()
2136 .await
2137 .start_language_server(&request, session.connection_id)
2138 .await?;
2139
2140 broadcast(
2141 Some(session.connection_id),
2142 guest_connection_ids.iter().copied(),
2143 |connection_id| {
2144 session
2145 .peer
2146 .forward_send(session.connection_id, connection_id, request.clone())
2147 },
2148 );
2149 Ok(())
2150}
2151
2152/// Notify other participants that a language server has changed.
2153async fn update_language_server(
2154 request: proto::UpdateLanguageServer,
2155 session: Session,
2156) -> Result<()> {
2157 let project_id = ProjectId::from_proto(request.project_id);
2158 let project_connection_ids = session
2159 .db()
2160 .await
2161 .project_connection_ids(project_id, session.connection_id, true)
2162 .await?;
2163 broadcast(
2164 Some(session.connection_id),
2165 project_connection_ids.iter().copied(),
2166 |connection_id| {
2167 session
2168 .peer
2169 .forward_send(session.connection_id, connection_id, request.clone())
2170 },
2171 );
2172 Ok(())
2173}
2174
2175/// forward a project request to the host. These requests should be read only
2176/// as guests are allowed to send them.
2177async fn forward_read_only_project_request<T>(
2178 request: T,
2179 response: Response<T>,
2180 session: Session,
2181) -> Result<()>
2182where
2183 T: EntityMessage + RequestMessage,
2184{
2185 let project_id = ProjectId::from_proto(request.remote_entity_id());
2186 let host_connection_id = session
2187 .db()
2188 .await
2189 .host_for_read_only_project_request(project_id, session.connection_id)
2190 .await?;
2191 let payload = session
2192 .peer
2193 .forward_request(session.connection_id, host_connection_id, request)
2194 .await?;
2195 response.send(payload)?;
2196 Ok(())
2197}
2198
2199async fn forward_find_search_candidates_request(
2200 request: proto::FindSearchCandidates,
2201 response: Response<proto::FindSearchCandidates>,
2202 session: Session,
2203) -> Result<()> {
2204 let project_id = ProjectId::from_proto(request.remote_entity_id());
2205 let host_connection_id = session
2206 .db()
2207 .await
2208 .host_for_read_only_project_request(project_id, session.connection_id)
2209 .await?;
2210 let payload = session
2211 .peer
2212 .forward_request(session.connection_id, host_connection_id, request)
2213 .await?;
2214 response.send(payload)?;
2215 Ok(())
2216}
2217
2218/// forward a project request to the host. These requests are disallowed
2219/// for guests.
2220async fn forward_mutating_project_request<T>(
2221 request: T,
2222 response: Response<T>,
2223 session: Session,
2224) -> Result<()>
2225where
2226 T: EntityMessage + RequestMessage,
2227{
2228 let project_id = ProjectId::from_proto(request.remote_entity_id());
2229
2230 let host_connection_id = session
2231 .db()
2232 .await
2233 .host_for_mutating_project_request(project_id, session.connection_id)
2234 .await?;
2235 let payload = session
2236 .peer
2237 .forward_request(session.connection_id, host_connection_id, request)
2238 .await?;
2239 response.send(payload)?;
2240 Ok(())
2241}
2242
2243/// Notify other participants that a new buffer has been created
2244async fn create_buffer_for_peer(
2245 request: proto::CreateBufferForPeer,
2246 session: Session,
2247) -> Result<()> {
2248 session
2249 .db()
2250 .await
2251 .check_user_is_project_host(
2252 ProjectId::from_proto(request.project_id),
2253 session.connection_id,
2254 )
2255 .await?;
2256 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2257 session
2258 .peer
2259 .forward_send(session.connection_id, peer_id.into(), request)?;
2260 Ok(())
2261}
2262
2263/// Notify other participants that a buffer has been updated. This is
2264/// allowed for guests as long as the update is limited to selections.
2265async fn update_buffer(
2266 request: proto::UpdateBuffer,
2267 response: Response<proto::UpdateBuffer>,
2268 session: Session,
2269) -> Result<()> {
2270 let project_id = ProjectId::from_proto(request.project_id);
2271 let mut capability = Capability::ReadOnly;
2272
2273 for op in request.operations.iter() {
2274 match op.variant {
2275 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2276 Some(_) => capability = Capability::ReadWrite,
2277 }
2278 }
2279
2280 let host = {
2281 let guard = session
2282 .db()
2283 .await
2284 .connections_for_buffer_update(project_id, session.connection_id, capability)
2285 .await?;
2286
2287 let (host, guests) = &*guard;
2288
2289 broadcast(
2290 Some(session.connection_id),
2291 guests.clone(),
2292 |connection_id| {
2293 session
2294 .peer
2295 .forward_send(session.connection_id, connection_id, request.clone())
2296 },
2297 );
2298
2299 *host
2300 };
2301
2302 if host != session.connection_id {
2303 session
2304 .peer
2305 .forward_request(session.connection_id, host, request.clone())
2306 .await?;
2307 }
2308
2309 response.send(proto::Ack {})?;
2310 Ok(())
2311}
2312
2313async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2314 let project_id = ProjectId::from_proto(message.project_id);
2315
2316 let operation = message.operation.as_ref().context("invalid operation")?;
2317 let capability = match operation.variant.as_ref() {
2318 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2319 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2320 match buffer_op.variant {
2321 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2322 Capability::ReadOnly
2323 }
2324 _ => Capability::ReadWrite,
2325 }
2326 } else {
2327 Capability::ReadWrite
2328 }
2329 }
2330 Some(_) => Capability::ReadWrite,
2331 None => Capability::ReadOnly,
2332 };
2333
2334 let guard = session
2335 .db()
2336 .await
2337 .connections_for_buffer_update(project_id, session.connection_id, capability)
2338 .await?;
2339
2340 let (host, guests) = &*guard;
2341
2342 broadcast(
2343 Some(session.connection_id),
2344 guests.iter().chain([host]).copied(),
2345 |connection_id| {
2346 session
2347 .peer
2348 .forward_send(session.connection_id, connection_id, message.clone())
2349 },
2350 );
2351
2352 Ok(())
2353}
2354
2355/// Notify other participants that a project has been updated.
2356async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2357 request: T,
2358 session: Session,
2359) -> Result<()> {
2360 let project_id = ProjectId::from_proto(request.remote_entity_id());
2361 let project_connection_ids = session
2362 .db()
2363 .await
2364 .project_connection_ids(project_id, session.connection_id, false)
2365 .await?;
2366
2367 broadcast(
2368 Some(session.connection_id),
2369 project_connection_ids.iter().copied(),
2370 |connection_id| {
2371 session
2372 .peer
2373 .forward_send(session.connection_id, connection_id, request.clone())
2374 },
2375 );
2376 Ok(())
2377}
2378
2379/// Start following another user in a call.
2380async fn follow(
2381 request: proto::Follow,
2382 response: Response<proto::Follow>,
2383 session: Session,
2384) -> Result<()> {
2385 let room_id = RoomId::from_proto(request.room_id);
2386 let project_id = request.project_id.map(ProjectId::from_proto);
2387 let leader_id = request
2388 .leader_id
2389 .ok_or_else(|| anyhow!("invalid leader id"))?
2390 .into();
2391 let follower_id = session.connection_id;
2392
2393 session
2394 .db()
2395 .await
2396 .check_room_participants(room_id, leader_id, session.connection_id)
2397 .await?;
2398
2399 let response_payload = session
2400 .peer
2401 .forward_request(session.connection_id, leader_id, request)
2402 .await?;
2403 response.send(response_payload)?;
2404
2405 if let Some(project_id) = project_id {
2406 let room = session
2407 .db()
2408 .await
2409 .follow(room_id, project_id, leader_id, follower_id)
2410 .await?;
2411 room_updated(&room, &session.peer);
2412 }
2413
2414 Ok(())
2415}
2416
2417/// Stop following another user in a call.
2418async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2419 let room_id = RoomId::from_proto(request.room_id);
2420 let project_id = request.project_id.map(ProjectId::from_proto);
2421 let leader_id = request
2422 .leader_id
2423 .ok_or_else(|| anyhow!("invalid leader id"))?
2424 .into();
2425 let follower_id = session.connection_id;
2426
2427 session
2428 .db()
2429 .await
2430 .check_room_participants(room_id, leader_id, session.connection_id)
2431 .await?;
2432
2433 session
2434 .peer
2435 .forward_send(session.connection_id, leader_id, request)?;
2436
2437 if let Some(project_id) = project_id {
2438 let room = session
2439 .db()
2440 .await
2441 .unfollow(room_id, project_id, leader_id, follower_id)
2442 .await?;
2443 room_updated(&room, &session.peer);
2444 }
2445
2446 Ok(())
2447}
2448
2449/// Notify everyone following you of your current location.
2450async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2451 let room_id = RoomId::from_proto(request.room_id);
2452 let database = session.db.lock().await;
2453
2454 let connection_ids = if let Some(project_id) = request.project_id {
2455 let project_id = ProjectId::from_proto(project_id);
2456 database
2457 .project_connection_ids(project_id, session.connection_id, true)
2458 .await?
2459 } else {
2460 database
2461 .room_connection_ids(room_id, session.connection_id)
2462 .await?
2463 };
2464
2465 // For now, don't send view update messages back to that view's current leader.
2466 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2467 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2468 _ => None,
2469 });
2470
2471 for connection_id in connection_ids.iter().cloned() {
2472 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2473 session
2474 .peer
2475 .forward_send(session.connection_id, connection_id, request.clone())?;
2476 }
2477 }
2478 Ok(())
2479}
2480
2481/// Get public data about users.
2482async fn get_users(
2483 request: proto::GetUsers,
2484 response: Response<proto::GetUsers>,
2485 session: Session,
2486) -> Result<()> {
2487 let user_ids = request
2488 .user_ids
2489 .into_iter()
2490 .map(UserId::from_proto)
2491 .collect();
2492 let users = session
2493 .db()
2494 .await
2495 .get_users_by_ids(user_ids)
2496 .await?
2497 .into_iter()
2498 .map(|user| proto::User {
2499 id: user.id.to_proto(),
2500 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2501 github_login: user.github_login,
2502 email: user.email_address,
2503 name: user.name,
2504 })
2505 .collect();
2506 response.send(proto::UsersResponse { users })?;
2507 Ok(())
2508}
2509
2510/// Search for users (to invite) buy Github login
2511async fn fuzzy_search_users(
2512 request: proto::FuzzySearchUsers,
2513 response: Response<proto::FuzzySearchUsers>,
2514 session: Session,
2515) -> Result<()> {
2516 let query = request.query;
2517 let users = match query.len() {
2518 0 => vec![],
2519 1 | 2 => session
2520 .db()
2521 .await
2522 .get_user_by_github_login(&query)
2523 .await?
2524 .into_iter()
2525 .collect(),
2526 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2527 };
2528 let users = users
2529 .into_iter()
2530 .filter(|user| user.id != session.user_id())
2531 .map(|user| proto::User {
2532 id: user.id.to_proto(),
2533 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2534 github_login: user.github_login,
2535 name: user.name,
2536 email: user.email_address,
2537 })
2538 .collect();
2539 response.send(proto::UsersResponse { users })?;
2540 Ok(())
2541}
2542
2543/// Send a contact request to another user.
2544async fn request_contact(
2545 request: proto::RequestContact,
2546 response: Response<proto::RequestContact>,
2547 session: Session,
2548) -> Result<()> {
2549 let requester_id = session.user_id();
2550 let responder_id = UserId::from_proto(request.responder_id);
2551 if requester_id == responder_id {
2552 return Err(anyhow!("cannot add yourself as a contact"))?;
2553 }
2554
2555 let notifications = session
2556 .db()
2557 .await
2558 .send_contact_request(requester_id, responder_id)
2559 .await?;
2560
2561 // Update outgoing contact requests of requester
2562 let mut update = proto::UpdateContacts::default();
2563 update.outgoing_requests.push(responder_id.to_proto());
2564 for connection_id in session
2565 .connection_pool()
2566 .await
2567 .user_connection_ids(requester_id)
2568 {
2569 session.peer.send(connection_id, update.clone())?;
2570 }
2571
2572 // Update incoming contact requests of responder
2573 let mut update = proto::UpdateContacts::default();
2574 update
2575 .incoming_requests
2576 .push(proto::IncomingContactRequest {
2577 requester_id: requester_id.to_proto(),
2578 });
2579 let connection_pool = session.connection_pool().await;
2580 for connection_id in connection_pool.user_connection_ids(responder_id) {
2581 session.peer.send(connection_id, update.clone())?;
2582 }
2583
2584 send_notifications(&connection_pool, &session.peer, notifications);
2585
2586 response.send(proto::Ack {})?;
2587 Ok(())
2588}
2589
2590/// Accept or decline a contact request
2591async fn respond_to_contact_request(
2592 request: proto::RespondToContactRequest,
2593 response: Response<proto::RespondToContactRequest>,
2594 session: Session,
2595) -> Result<()> {
2596 let responder_id = session.user_id();
2597 let requester_id = UserId::from_proto(request.requester_id);
2598 let db = session.db().await;
2599 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2600 db.dismiss_contact_notification(responder_id, requester_id)
2601 .await?;
2602 } else {
2603 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2604
2605 let notifications = db
2606 .respond_to_contact_request(responder_id, requester_id, accept)
2607 .await?;
2608 let requester_busy = db.is_user_busy(requester_id).await?;
2609 let responder_busy = db.is_user_busy(responder_id).await?;
2610
2611 let pool = session.connection_pool().await;
2612 // Update responder with new contact
2613 let mut update = proto::UpdateContacts::default();
2614 if accept {
2615 update
2616 .contacts
2617 .push(contact_for_user(requester_id, requester_busy, &pool));
2618 }
2619 update
2620 .remove_incoming_requests
2621 .push(requester_id.to_proto());
2622 for connection_id in pool.user_connection_ids(responder_id) {
2623 session.peer.send(connection_id, update.clone())?;
2624 }
2625
2626 // Update requester with new contact
2627 let mut update = proto::UpdateContacts::default();
2628 if accept {
2629 update
2630 .contacts
2631 .push(contact_for_user(responder_id, responder_busy, &pool));
2632 }
2633 update
2634 .remove_outgoing_requests
2635 .push(responder_id.to_proto());
2636
2637 for connection_id in pool.user_connection_ids(requester_id) {
2638 session.peer.send(connection_id, update.clone())?;
2639 }
2640
2641 send_notifications(&pool, &session.peer, notifications);
2642 }
2643
2644 response.send(proto::Ack {})?;
2645 Ok(())
2646}
2647
2648/// Remove a contact.
2649async fn remove_contact(
2650 request: proto::RemoveContact,
2651 response: Response<proto::RemoveContact>,
2652 session: Session,
2653) -> Result<()> {
2654 let requester_id = session.user_id();
2655 let responder_id = UserId::from_proto(request.user_id);
2656 let db = session.db().await;
2657 let (contact_accepted, deleted_notification_id) =
2658 db.remove_contact(requester_id, responder_id).await?;
2659
2660 let pool = session.connection_pool().await;
2661 // Update outgoing contact requests of requester
2662 let mut update = proto::UpdateContacts::default();
2663 if contact_accepted {
2664 update.remove_contacts.push(responder_id.to_proto());
2665 } else {
2666 update
2667 .remove_outgoing_requests
2668 .push(responder_id.to_proto());
2669 }
2670 for connection_id in pool.user_connection_ids(requester_id) {
2671 session.peer.send(connection_id, update.clone())?;
2672 }
2673
2674 // Update incoming contact requests of responder
2675 let mut update = proto::UpdateContacts::default();
2676 if contact_accepted {
2677 update.remove_contacts.push(requester_id.to_proto());
2678 } else {
2679 update
2680 .remove_incoming_requests
2681 .push(requester_id.to_proto());
2682 }
2683 for connection_id in pool.user_connection_ids(responder_id) {
2684 session.peer.send(connection_id, update.clone())?;
2685 if let Some(notification_id) = deleted_notification_id {
2686 session.peer.send(
2687 connection_id,
2688 proto::DeleteNotification {
2689 notification_id: notification_id.to_proto(),
2690 },
2691 )?;
2692 }
2693 }
2694
2695 response.send(proto::Ack {})?;
2696 Ok(())
2697}
2698
2699fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2700 version.0.minor() < 139
2701}
2702
2703async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2704 let plan = session.current_plan(&session.db().await).await?;
2705
2706 session
2707 .peer
2708 .send(
2709 session.connection_id,
2710 proto::UpdateUserPlan { plan: plan.into() },
2711 )
2712 .trace_err();
2713
2714 Ok(())
2715}
2716
2717async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2718 subscribe_user_to_channels(session.user_id(), &session).await?;
2719 Ok(())
2720}
2721
2722async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2723 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2724 let mut pool = session.connection_pool().await;
2725 for membership in &channels_for_user.channel_memberships {
2726 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2727 }
2728 session.peer.send(
2729 session.connection_id,
2730 build_update_user_channels(&channels_for_user),
2731 )?;
2732 session.peer.send(
2733 session.connection_id,
2734 build_channels_update(channels_for_user),
2735 )?;
2736 Ok(())
2737}
2738
2739/// Creates a new channel.
2740async fn create_channel(
2741 request: proto::CreateChannel,
2742 response: Response<proto::CreateChannel>,
2743 session: Session,
2744) -> Result<()> {
2745 let db = session.db().await;
2746
2747 let parent_id = request.parent_id.map(ChannelId::from_proto);
2748 let (channel, membership) = db
2749 .create_channel(&request.name, parent_id, session.user_id())
2750 .await?;
2751
2752 let root_id = channel.root_id();
2753 let channel = Channel::from_model(channel);
2754
2755 response.send(proto::CreateChannelResponse {
2756 channel: Some(channel.to_proto()),
2757 parent_id: request.parent_id,
2758 })?;
2759
2760 let mut connection_pool = session.connection_pool().await;
2761 if let Some(membership) = membership {
2762 connection_pool.subscribe_to_channel(
2763 membership.user_id,
2764 membership.channel_id,
2765 membership.role,
2766 );
2767 let update = proto::UpdateUserChannels {
2768 channel_memberships: vec![proto::ChannelMembership {
2769 channel_id: membership.channel_id.to_proto(),
2770 role: membership.role.into(),
2771 }],
2772 ..Default::default()
2773 };
2774 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2775 session.peer.send(connection_id, update.clone())?;
2776 }
2777 }
2778
2779 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2780 if !role.can_see_channel(channel.visibility) {
2781 continue;
2782 }
2783
2784 let update = proto::UpdateChannels {
2785 channels: vec![channel.to_proto()],
2786 ..Default::default()
2787 };
2788 session.peer.send(connection_id, update.clone())?;
2789 }
2790
2791 Ok(())
2792}
2793
2794/// Delete a channel
2795async fn delete_channel(
2796 request: proto::DeleteChannel,
2797 response: Response<proto::DeleteChannel>,
2798 session: Session,
2799) -> Result<()> {
2800 let db = session.db().await;
2801
2802 let channel_id = request.channel_id;
2803 let (root_channel, removed_channels) = db
2804 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2805 .await?;
2806 response.send(proto::Ack {})?;
2807
2808 // Notify members of removed channels
2809 let mut update = proto::UpdateChannels::default();
2810 update
2811 .delete_channels
2812 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2813
2814 let connection_pool = session.connection_pool().await;
2815 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2816 session.peer.send(connection_id, update.clone())?;
2817 }
2818
2819 Ok(())
2820}
2821
2822/// Invite someone to join a channel.
2823async fn invite_channel_member(
2824 request: proto::InviteChannelMember,
2825 response: Response<proto::InviteChannelMember>,
2826 session: Session,
2827) -> Result<()> {
2828 let db = session.db().await;
2829 let channel_id = ChannelId::from_proto(request.channel_id);
2830 let invitee_id = UserId::from_proto(request.user_id);
2831 let InviteMemberResult {
2832 channel,
2833 notifications,
2834 } = db
2835 .invite_channel_member(
2836 channel_id,
2837 invitee_id,
2838 session.user_id(),
2839 request.role().into(),
2840 )
2841 .await?;
2842
2843 let update = proto::UpdateChannels {
2844 channel_invitations: vec![channel.to_proto()],
2845 ..Default::default()
2846 };
2847
2848 let connection_pool = session.connection_pool().await;
2849 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2850 session.peer.send(connection_id, update.clone())?;
2851 }
2852
2853 send_notifications(&connection_pool, &session.peer, notifications);
2854
2855 response.send(proto::Ack {})?;
2856 Ok(())
2857}
2858
2859/// remove someone from a channel
2860async fn remove_channel_member(
2861 request: proto::RemoveChannelMember,
2862 response: Response<proto::RemoveChannelMember>,
2863 session: Session,
2864) -> Result<()> {
2865 let db = session.db().await;
2866 let channel_id = ChannelId::from_proto(request.channel_id);
2867 let member_id = UserId::from_proto(request.user_id);
2868
2869 let RemoveChannelMemberResult {
2870 membership_update,
2871 notification_id,
2872 } = db
2873 .remove_channel_member(channel_id, member_id, session.user_id())
2874 .await?;
2875
2876 let mut connection_pool = session.connection_pool().await;
2877 notify_membership_updated(
2878 &mut connection_pool,
2879 membership_update,
2880 member_id,
2881 &session.peer,
2882 );
2883 for connection_id in connection_pool.user_connection_ids(member_id) {
2884 if let Some(notification_id) = notification_id {
2885 session
2886 .peer
2887 .send(
2888 connection_id,
2889 proto::DeleteNotification {
2890 notification_id: notification_id.to_proto(),
2891 },
2892 )
2893 .trace_err();
2894 }
2895 }
2896
2897 response.send(proto::Ack {})?;
2898 Ok(())
2899}
2900
2901/// Toggle the channel between public and private.
2902/// Care is taken to maintain the invariant that public channels only descend from public channels,
2903/// (though members-only channels can appear at any point in the hierarchy).
2904async fn set_channel_visibility(
2905 request: proto::SetChannelVisibility,
2906 response: Response<proto::SetChannelVisibility>,
2907 session: Session,
2908) -> Result<()> {
2909 let db = session.db().await;
2910 let channel_id = ChannelId::from_proto(request.channel_id);
2911 let visibility = request.visibility().into();
2912
2913 let channel_model = db
2914 .set_channel_visibility(channel_id, visibility, session.user_id())
2915 .await?;
2916 let root_id = channel_model.root_id();
2917 let channel = Channel::from_model(channel_model);
2918
2919 let mut connection_pool = session.connection_pool().await;
2920 for (user_id, role) in connection_pool
2921 .channel_user_ids(root_id)
2922 .collect::<Vec<_>>()
2923 .into_iter()
2924 {
2925 let update = if role.can_see_channel(channel.visibility) {
2926 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2927 proto::UpdateChannels {
2928 channels: vec![channel.to_proto()],
2929 ..Default::default()
2930 }
2931 } else {
2932 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2933 proto::UpdateChannels {
2934 delete_channels: vec![channel.id.to_proto()],
2935 ..Default::default()
2936 }
2937 };
2938
2939 for connection_id in connection_pool.user_connection_ids(user_id) {
2940 session.peer.send(connection_id, update.clone())?;
2941 }
2942 }
2943
2944 response.send(proto::Ack {})?;
2945 Ok(())
2946}
2947
2948/// Alter the role for a user in the channel.
2949async fn set_channel_member_role(
2950 request: proto::SetChannelMemberRole,
2951 response: Response<proto::SetChannelMemberRole>,
2952 session: Session,
2953) -> Result<()> {
2954 let db = session.db().await;
2955 let channel_id = ChannelId::from_proto(request.channel_id);
2956 let member_id = UserId::from_proto(request.user_id);
2957 let result = db
2958 .set_channel_member_role(
2959 channel_id,
2960 session.user_id(),
2961 member_id,
2962 request.role().into(),
2963 )
2964 .await?;
2965
2966 match result {
2967 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2968 let mut connection_pool = session.connection_pool().await;
2969 notify_membership_updated(
2970 &mut connection_pool,
2971 membership_update,
2972 member_id,
2973 &session.peer,
2974 )
2975 }
2976 db::SetMemberRoleResult::InviteUpdated(channel) => {
2977 let update = proto::UpdateChannels {
2978 channel_invitations: vec![channel.to_proto()],
2979 ..Default::default()
2980 };
2981
2982 for connection_id in session
2983 .connection_pool()
2984 .await
2985 .user_connection_ids(member_id)
2986 {
2987 session.peer.send(connection_id, update.clone())?;
2988 }
2989 }
2990 }
2991
2992 response.send(proto::Ack {})?;
2993 Ok(())
2994}
2995
2996/// Change the name of a channel
2997async fn rename_channel(
2998 request: proto::RenameChannel,
2999 response: Response<proto::RenameChannel>,
3000 session: Session,
3001) -> Result<()> {
3002 let db = session.db().await;
3003 let channel_id = ChannelId::from_proto(request.channel_id);
3004 let channel_model = db
3005 .rename_channel(channel_id, session.user_id(), &request.name)
3006 .await?;
3007 let root_id = channel_model.root_id();
3008 let channel = Channel::from_model(channel_model);
3009
3010 response.send(proto::RenameChannelResponse {
3011 channel: Some(channel.to_proto()),
3012 })?;
3013
3014 let connection_pool = session.connection_pool().await;
3015 let update = proto::UpdateChannels {
3016 channels: vec![channel.to_proto()],
3017 ..Default::default()
3018 };
3019 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3020 if role.can_see_channel(channel.visibility) {
3021 session.peer.send(connection_id, update.clone())?;
3022 }
3023 }
3024
3025 Ok(())
3026}
3027
3028/// Move a channel to a new parent.
3029async fn move_channel(
3030 request: proto::MoveChannel,
3031 response: Response<proto::MoveChannel>,
3032 session: Session,
3033) -> Result<()> {
3034 let channel_id = ChannelId::from_proto(request.channel_id);
3035 let to = ChannelId::from_proto(request.to);
3036
3037 let (root_id, channels) = session
3038 .db()
3039 .await
3040 .move_channel(channel_id, to, session.user_id())
3041 .await?;
3042
3043 let connection_pool = session.connection_pool().await;
3044 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3045 let channels = channels
3046 .iter()
3047 .filter_map(|channel| {
3048 if role.can_see_channel(channel.visibility) {
3049 Some(channel.to_proto())
3050 } else {
3051 None
3052 }
3053 })
3054 .collect::<Vec<_>>();
3055 if channels.is_empty() {
3056 continue;
3057 }
3058
3059 let update = proto::UpdateChannels {
3060 channels,
3061 ..Default::default()
3062 };
3063
3064 session.peer.send(connection_id, update.clone())?;
3065 }
3066
3067 response.send(Ack {})?;
3068 Ok(())
3069}
3070
3071/// Get the list of channel members
3072async fn get_channel_members(
3073 request: proto::GetChannelMembers,
3074 response: Response<proto::GetChannelMembers>,
3075 session: Session,
3076) -> Result<()> {
3077 let db = session.db().await;
3078 let channel_id = ChannelId::from_proto(request.channel_id);
3079 let limit = if request.limit == 0 {
3080 u16::MAX as u64
3081 } else {
3082 request.limit
3083 };
3084 let (members, users) = db
3085 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3086 .await?;
3087 response.send(proto::GetChannelMembersResponse { members, users })?;
3088 Ok(())
3089}
3090
3091/// Accept or decline a channel invitation.
3092async fn respond_to_channel_invite(
3093 request: proto::RespondToChannelInvite,
3094 response: Response<proto::RespondToChannelInvite>,
3095 session: Session,
3096) -> Result<()> {
3097 let db = session.db().await;
3098 let channel_id = ChannelId::from_proto(request.channel_id);
3099 let RespondToChannelInvite {
3100 membership_update,
3101 notifications,
3102 } = db
3103 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3104 .await?;
3105
3106 let mut connection_pool = session.connection_pool().await;
3107 if let Some(membership_update) = membership_update {
3108 notify_membership_updated(
3109 &mut connection_pool,
3110 membership_update,
3111 session.user_id(),
3112 &session.peer,
3113 );
3114 } else {
3115 let update = proto::UpdateChannels {
3116 remove_channel_invitations: vec![channel_id.to_proto()],
3117 ..Default::default()
3118 };
3119
3120 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3121 session.peer.send(connection_id, update.clone())?;
3122 }
3123 };
3124
3125 send_notifications(&connection_pool, &session.peer, notifications);
3126
3127 response.send(proto::Ack {})?;
3128
3129 Ok(())
3130}
3131
3132/// Join the channels' room
3133async fn join_channel(
3134 request: proto::JoinChannel,
3135 response: Response<proto::JoinChannel>,
3136 session: Session,
3137) -> Result<()> {
3138 let channel_id = ChannelId::from_proto(request.channel_id);
3139 join_channel_internal(channel_id, Box::new(response), session).await
3140}
3141
3142trait JoinChannelInternalResponse {
3143 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3144}
3145impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3146 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3147 Response::<proto::JoinChannel>::send(self, result)
3148 }
3149}
3150impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3151 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3152 Response::<proto::JoinRoom>::send(self, result)
3153 }
3154}
3155
/// Joins the room associated with `channel_id` on behalf of the session's user.
///
/// `response` may be any handle implementing [`JoinChannelInternalResponse`]. In order,
/// this function:
/// 1. leaves any room still held by a stale connection of the same user,
/// 2. records the join in the database,
/// 3. builds LiveKit connection info when a LiveKit client is configured,
/// 4. sends the `JoinRoomResponse`,
/// 5. forwards any membership change via `notify_membership_updated`, announces the
///    updated room via `room_updated`, publishes the channel's participant list via
///    `channel_updated`, and finally refreshes the user's contacts.
///
/// Returns an error if a database call fails, if the response cannot be sent, or if
/// the database did not return a channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The database handle is released while leaving the stale room and
            // re-acquired afterwards (presumably because `leave_room_for_session`
            // needs the database itself — confirm against its definition).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when creating the token
        // fails: the `?` on `trace_err()` short-circuits the closure with `None`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; every other
                    // role receives a regular room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out updates to other connections.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database handle and connection pool guard from the block above are out of
    // scope here; the pool is acquired again just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3251
3252/// Start editing the channel notes
3253async fn join_channel_buffer(
3254 request: proto::JoinChannelBuffer,
3255 response: Response<proto::JoinChannelBuffer>,
3256 session: Session,
3257) -> Result<()> {
3258 let db = session.db().await;
3259 let channel_id = ChannelId::from_proto(request.channel_id);
3260
3261 let open_response = db
3262 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3263 .await?;
3264
3265 let collaborators = open_response.collaborators.clone();
3266 response.send(open_response)?;
3267
3268 let update = UpdateChannelBufferCollaborators {
3269 channel_id: channel_id.to_proto(),
3270 collaborators: collaborators.clone(),
3271 };
3272 channel_buffer_updated(
3273 session.connection_id,
3274 collaborators
3275 .iter()
3276 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3277 &update,
3278 &session.peer,
3279 );
3280
3281 Ok(())
3282}
3283
3284/// Edit the channel notes
3285async fn update_channel_buffer(
3286 request: proto::UpdateChannelBuffer,
3287 session: Session,
3288) -> Result<()> {
3289 let db = session.db().await;
3290 let channel_id = ChannelId::from_proto(request.channel_id);
3291
3292 let (collaborators, epoch, version) = db
3293 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3294 .await?;
3295
3296 channel_buffer_updated(
3297 session.connection_id,
3298 collaborators.clone(),
3299 &proto::UpdateChannelBuffer {
3300 channel_id: channel_id.to_proto(),
3301 operations: request.operations,
3302 },
3303 &session.peer,
3304 );
3305
3306 let pool = &*session.connection_pool().await;
3307
3308 let non_collaborators =
3309 pool.channel_connection_ids(channel_id)
3310 .filter_map(|(connection_id, _)| {
3311 if collaborators.contains(&connection_id) {
3312 None
3313 } else {
3314 Some(connection_id)
3315 }
3316 });
3317
3318 broadcast(None, non_collaborators, |peer_id| {
3319 session.peer.send(
3320 peer_id,
3321 proto::UpdateChannels {
3322 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3323 channel_id: channel_id.to_proto(),
3324 epoch: epoch as u64,
3325 version: version.clone(),
3326 }],
3327 ..Default::default()
3328 },
3329 )
3330 });
3331
3332 Ok(())
3333}
3334
3335/// Rejoin the channel notes after a connection blip
3336async fn rejoin_channel_buffers(
3337 request: proto::RejoinChannelBuffers,
3338 response: Response<proto::RejoinChannelBuffers>,
3339 session: Session,
3340) -> Result<()> {
3341 let db = session.db().await;
3342 let buffers = db
3343 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3344 .await?;
3345
3346 for rejoined_buffer in &buffers {
3347 let collaborators_to_notify = rejoined_buffer
3348 .buffer
3349 .collaborators
3350 .iter()
3351 .filter_map(|c| Some(c.peer_id?.into()));
3352 channel_buffer_updated(
3353 session.connection_id,
3354 collaborators_to_notify,
3355 &proto::UpdateChannelBufferCollaborators {
3356 channel_id: rejoined_buffer.buffer.channel_id,
3357 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3358 },
3359 &session.peer,
3360 );
3361 }
3362
3363 response.send(proto::RejoinChannelBuffersResponse {
3364 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3365 })?;
3366
3367 Ok(())
3368}
3369
3370/// Stop editing the channel notes
3371async fn leave_channel_buffer(
3372 request: proto::LeaveChannelBuffer,
3373 response: Response<proto::LeaveChannelBuffer>,
3374 session: Session,
3375) -> Result<()> {
3376 let db = session.db().await;
3377 let channel_id = ChannelId::from_proto(request.channel_id);
3378
3379 let left_buffer = db
3380 .leave_channel_buffer(channel_id, session.connection_id)
3381 .await?;
3382
3383 response.send(Ack {})?;
3384
3385 channel_buffer_updated(
3386 session.connection_id,
3387 left_buffer.connections,
3388 &proto::UpdateChannelBufferCollaborators {
3389 channel_id: channel_id.to_proto(),
3390 collaborators: left_buffer.collaborators,
3391 },
3392 &session.peer,
3393 );
3394
3395 Ok(())
3396}
3397
3398fn channel_buffer_updated<T: EnvelopedMessage>(
3399 sender_id: ConnectionId,
3400 collaborators: impl IntoIterator<Item = ConnectionId>,
3401 message: &T,
3402 peer: &Peer,
3403) {
3404 broadcast(Some(sender_id), collaborators, |peer_id| {
3405 peer.send(peer_id, message.clone())
3406 });
3407}
3408
3409fn send_notifications(
3410 connection_pool: &ConnectionPool,
3411 peer: &Peer,
3412 notifications: db::NotificationBatch,
3413) {
3414 for (user_id, notification) in notifications {
3415 for connection_id in connection_pool.user_connection_ids(user_id) {
3416 if let Err(error) = peer.send(
3417 connection_id,
3418 proto::AddNotification {
3419 notification: Some(notification.clone()),
3420 },
3421 ) {
3422 tracing::error!(
3423 "failed to send notification to {:?} {}",
3424 connection_id,
3425 error
3426 );
3427 }
3428 }
3429 }
3430}
3431
3432/// Send a message to the channel
3433async fn send_channel_message(
3434 request: proto::SendChannelMessage,
3435 response: Response<proto::SendChannelMessage>,
3436 session: Session,
3437) -> Result<()> {
3438 // Validate the message body.
3439 let body = request.body.trim().to_string();
3440 if body.len() > MAX_MESSAGE_LEN {
3441 return Err(anyhow!("message is too long"))?;
3442 }
3443 if body.is_empty() {
3444 return Err(anyhow!("message can't be blank"))?;
3445 }
3446
3447 // TODO: adjust mentions if body is trimmed
3448
3449 let timestamp = OffsetDateTime::now_utc();
3450 let nonce = request
3451 .nonce
3452 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3453
3454 let channel_id = ChannelId::from_proto(request.channel_id);
3455 let CreatedChannelMessage {
3456 message_id,
3457 participant_connection_ids,
3458 notifications,
3459 } = session
3460 .db()
3461 .await
3462 .create_channel_message(
3463 channel_id,
3464 session.user_id(),
3465 &body,
3466 &request.mentions,
3467 timestamp,
3468 nonce.clone().into(),
3469 request.reply_to_message_id.map(MessageId::from_proto),
3470 )
3471 .await?;
3472
3473 let message = proto::ChannelMessage {
3474 sender_id: session.user_id().to_proto(),
3475 id: message_id.to_proto(),
3476 body,
3477 mentions: request.mentions,
3478 timestamp: timestamp.unix_timestamp() as u64,
3479 nonce: Some(nonce),
3480 reply_to_message_id: request.reply_to_message_id,
3481 edited_at: None,
3482 };
3483 broadcast(
3484 Some(session.connection_id),
3485 participant_connection_ids.clone(),
3486 |connection| {
3487 session.peer.send(
3488 connection,
3489 proto::ChannelMessageSent {
3490 channel_id: channel_id.to_proto(),
3491 message: Some(message.clone()),
3492 },
3493 )
3494 },
3495 );
3496 response.send(proto::SendChannelMessageResponse {
3497 message: Some(message),
3498 })?;
3499
3500 let pool = &*session.connection_pool().await;
3501 let non_participants =
3502 pool.channel_connection_ids(channel_id)
3503 .filter_map(|(connection_id, _)| {
3504 if participant_connection_ids.contains(&connection_id) {
3505 None
3506 } else {
3507 Some(connection_id)
3508 }
3509 });
3510 broadcast(None, non_participants, |peer_id| {
3511 session.peer.send(
3512 peer_id,
3513 proto::UpdateChannels {
3514 latest_channel_message_ids: vec![proto::ChannelMessageId {
3515 channel_id: channel_id.to_proto(),
3516 message_id: message_id.to_proto(),
3517 }],
3518 ..Default::default()
3519 },
3520 )
3521 });
3522 send_notifications(pool, &session.peer, notifications);
3523
3524 Ok(())
3525}
3526
3527/// Delete a channel message
3528async fn remove_channel_message(
3529 request: proto::RemoveChannelMessage,
3530 response: Response<proto::RemoveChannelMessage>,
3531 session: Session,
3532) -> Result<()> {
3533 let channel_id = ChannelId::from_proto(request.channel_id);
3534 let message_id = MessageId::from_proto(request.message_id);
3535 let (connection_ids, existing_notification_ids) = session
3536 .db()
3537 .await
3538 .remove_channel_message(channel_id, message_id, session.user_id())
3539 .await?;
3540
3541 broadcast(
3542 Some(session.connection_id),
3543 connection_ids,
3544 move |connection| {
3545 session.peer.send(connection, request.clone())?;
3546
3547 for notification_id in &existing_notification_ids {
3548 session.peer.send(
3549 connection,
3550 proto::DeleteNotification {
3551 notification_id: (*notification_id).to_proto(),
3552 },
3553 )?;
3554 }
3555
3556 Ok(())
3557 },
3558 );
3559 response.send(proto::Ack {})?;
3560 Ok(())
3561}
3562
3563async fn update_channel_message(
3564 request: proto::UpdateChannelMessage,
3565 response: Response<proto::UpdateChannelMessage>,
3566 session: Session,
3567) -> Result<()> {
3568 let channel_id = ChannelId::from_proto(request.channel_id);
3569 let message_id = MessageId::from_proto(request.message_id);
3570 let updated_at = OffsetDateTime::now_utc();
3571 let UpdatedChannelMessage {
3572 message_id,
3573 participant_connection_ids,
3574 notifications,
3575 reply_to_message_id,
3576 timestamp,
3577 deleted_mention_notification_ids,
3578 updated_mention_notifications,
3579 } = session
3580 .db()
3581 .await
3582 .update_channel_message(
3583 channel_id,
3584 message_id,
3585 session.user_id(),
3586 request.body.as_str(),
3587 &request.mentions,
3588 updated_at,
3589 )
3590 .await?;
3591
3592 let nonce = request
3593 .nonce
3594 .clone()
3595 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3596
3597 let message = proto::ChannelMessage {
3598 sender_id: session.user_id().to_proto(),
3599 id: message_id.to_proto(),
3600 body: request.body.clone(),
3601 mentions: request.mentions.clone(),
3602 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3603 nonce: Some(nonce),
3604 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3605 edited_at: Some(updated_at.unix_timestamp() as u64),
3606 };
3607
3608 response.send(proto::Ack {})?;
3609
3610 let pool = &*session.connection_pool().await;
3611 broadcast(
3612 Some(session.connection_id),
3613 participant_connection_ids,
3614 |connection| {
3615 session.peer.send(
3616 connection,
3617 proto::ChannelMessageUpdate {
3618 channel_id: channel_id.to_proto(),
3619 message: Some(message.clone()),
3620 },
3621 )?;
3622
3623 for notification_id in &deleted_mention_notification_ids {
3624 session.peer.send(
3625 connection,
3626 proto::DeleteNotification {
3627 notification_id: (*notification_id).to_proto(),
3628 },
3629 )?;
3630 }
3631
3632 for notification in &updated_mention_notifications {
3633 session.peer.send(
3634 connection,
3635 proto::UpdateNotification {
3636 notification: Some(notification.clone()),
3637 },
3638 )?;
3639 }
3640
3641 Ok(())
3642 },
3643 );
3644
3645 send_notifications(pool, &session.peer, notifications);
3646
3647 Ok(())
3648}
3649
3650/// Mark a channel message as read
3651async fn acknowledge_channel_message(
3652 request: proto::AckChannelMessage,
3653 session: Session,
3654) -> Result<()> {
3655 let channel_id = ChannelId::from_proto(request.channel_id);
3656 let message_id = MessageId::from_proto(request.message_id);
3657 let notifications = session
3658 .db()
3659 .await
3660 .observe_channel_message(channel_id, session.user_id(), message_id)
3661 .await?;
3662 send_notifications(
3663 &*session.connection_pool().await,
3664 &session.peer,
3665 notifications,
3666 );
3667 Ok(())
3668}
3669
3670/// Mark a buffer version as synced
3671async fn acknowledge_buffer_version(
3672 request: proto::AckBufferOperation,
3673 session: Session,
3674) -> Result<()> {
3675 let buffer_id = BufferId::from_proto(request.buffer_id);
3676 session
3677 .db()
3678 .await
3679 .observe_buffer_version(
3680 buffer_id,
3681 session.user_id(),
3682 request.epoch as i32,
3683 &request.version,
3684 )
3685 .await?;
3686 Ok(())
3687}
3688
3689async fn count_language_model_tokens(
3690 request: proto::CountLanguageModelTokens,
3691 response: Response<proto::CountLanguageModelTokens>,
3692 session: Session,
3693 config: &Config,
3694) -> Result<()> {
3695 authorize_access_to_legacy_llm_endpoints(&session).await?;
3696
3697 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3698 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3699 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3700 };
3701
3702 session
3703 .app_state
3704 .rate_limiter
3705 .check(&*rate_limit, session.user_id())
3706 .await?;
3707
3708 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3709 Some(proto::LanguageModelProvider::Google) => {
3710 let api_key = config
3711 .google_ai_api_key
3712 .as_ref()
3713 .context("no Google AI API key configured on the server")?;
3714 google_ai::count_tokens(
3715 session.http_client.as_ref(),
3716 google_ai::API_URL,
3717 api_key,
3718 serde_json::from_str(&request.request)?,
3719 )
3720 .await?
3721 }
3722 _ => return Err(anyhow!("unsupported provider"))?,
3723 };
3724
3725 response.send(proto::CountLanguageModelTokensResponse {
3726 token_count: result.total_tokens as u32,
3727 })?;
3728
3729 Ok(())
3730}
3731
3732struct ZedProCountLanguageModelTokensRateLimit;
3733
3734impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3735 fn capacity(&self) -> usize {
3736 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3737 .ok()
3738 .and_then(|v| v.parse().ok())
3739 .unwrap_or(600) // Picked arbitrarily
3740 }
3741
3742 fn refill_duration(&self) -> chrono::Duration {
3743 chrono::Duration::hours(1)
3744 }
3745
3746 fn db_name(&self) -> &'static str {
3747 "zed-pro:count-language-model-tokens"
3748 }
3749}
3750
3751struct FreeCountLanguageModelTokensRateLimit;
3752
3753impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3754 fn capacity(&self) -> usize {
3755 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3756 .ok()
3757 .and_then(|v| v.parse().ok())
3758 .unwrap_or(600 / 10) // Picked arbitrarily
3759 }
3760
3761 fn refill_duration(&self) -> chrono::Duration {
3762 chrono::Duration::hours(1)
3763 }
3764
3765 fn db_name(&self) -> &'static str {
3766 "free:count-language-model-tokens"
3767 }
3768}
3769
3770struct ZedProComputeEmbeddingsRateLimit;
3771
3772impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3773 fn capacity(&self) -> usize {
3774 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3775 .ok()
3776 .and_then(|v| v.parse().ok())
3777 .unwrap_or(5000) // Picked arbitrarily
3778 }
3779
3780 fn refill_duration(&self) -> chrono::Duration {
3781 chrono::Duration::hours(1)
3782 }
3783
3784 fn db_name(&self) -> &'static str {
3785 "zed-pro:compute-embeddings"
3786 }
3787}
3788
3789struct FreeComputeEmbeddingsRateLimit;
3790
3791impl RateLimit for FreeComputeEmbeddingsRateLimit {
3792 fn capacity(&self) -> usize {
3793 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3794 .ok()
3795 .and_then(|v| v.parse().ok())
3796 .unwrap_or(5000 / 10) // Picked arbitrarily
3797 }
3798
3799 fn refill_duration(&self) -> chrono::Duration {
3800 chrono::Duration::hours(1)
3801 }
3802
3803 fn db_name(&self) -> &'static str {
3804 "free:compute-embeddings"
3805 }
3806}
3807
3808async fn compute_embeddings(
3809 request: proto::ComputeEmbeddings,
3810 response: Response<proto::ComputeEmbeddings>,
3811 session: Session,
3812 api_key: Option<Arc<str>>,
3813) -> Result<()> {
3814 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3815 authorize_access_to_legacy_llm_endpoints(&session).await?;
3816
3817 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3818 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3819 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3820 };
3821
3822 session
3823 .app_state
3824 .rate_limiter
3825 .check(&*rate_limit, session.user_id())
3826 .await?;
3827
3828 let embeddings = match request.model.as_str() {
3829 "openai/text-embedding-3-small" => {
3830 open_ai::embed(
3831 session.http_client.as_ref(),
3832 OPEN_AI_API_URL,
3833 &api_key,
3834 OpenAiEmbeddingModel::TextEmbedding3Small,
3835 request.texts.iter().map(|text| text.as_str()),
3836 )
3837 .await?
3838 }
3839 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3840 };
3841
3842 let embeddings = request
3843 .texts
3844 .iter()
3845 .map(|text| {
3846 let mut hasher = sha2::Sha256::new();
3847 hasher.update(text.as_bytes());
3848 let result = hasher.finalize();
3849 result.to_vec()
3850 })
3851 .zip(
3852 embeddings
3853 .data
3854 .into_iter()
3855 .map(|embedding| embedding.embedding),
3856 )
3857 .collect::<HashMap<_, _>>();
3858
3859 let db = session.db().await;
3860 db.save_embeddings(&request.model, &embeddings)
3861 .await
3862 .context("failed to save embeddings")
3863 .trace_err();
3864
3865 response.send(proto::ComputeEmbeddingsResponse {
3866 embeddings: embeddings
3867 .into_iter()
3868 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3869 .collect(),
3870 })?;
3871 Ok(())
3872}
3873
3874async fn get_cached_embeddings(
3875 request: proto::GetCachedEmbeddings,
3876 response: Response<proto::GetCachedEmbeddings>,
3877 session: Session,
3878) -> Result<()> {
3879 authorize_access_to_legacy_llm_endpoints(&session).await?;
3880
3881 let db = session.db().await;
3882 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3883
3884 response.send(proto::GetCachedEmbeddingsResponse {
3885 embeddings: embeddings
3886 .into_iter()
3887 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3888 .collect(),
3889 })?;
3890 Ok(())
3891}
3892
3893/// This is leftover from before the LLM service.
3894///
3895/// The endpoints protected by this check will be moved there eventually.
3896async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3897 if session.is_staff() {
3898 Ok(())
3899 } else {
3900 Err(anyhow!("permission denied"))?
3901 }
3902}
3903
3904/// Get a Supermaven API key for the user
3905async fn get_supermaven_api_key(
3906 _request: proto::GetSupermavenApiKey,
3907 response: Response<proto::GetSupermavenApiKey>,
3908 session: Session,
3909) -> Result<()> {
3910 let user_id: String = session.user_id().to_string();
3911 if !session.is_staff() {
3912 return Err(anyhow!("supermaven not enabled for this account"))?;
3913 }
3914
3915 let email = session
3916 .email()
3917 .ok_or_else(|| anyhow!("user must have an email"))?;
3918
3919 let supermaven_admin_api = session
3920 .supermaven_client
3921 .as_ref()
3922 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3923
3924 let result = supermaven_admin_api
3925 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3926 .await?;
3927
3928 response.send(proto::GetSupermavenApiKeyResponse {
3929 api_key: result.api_key,
3930 })?;
3931
3932 Ok(())
3933}
3934
3935/// Start receiving chat updates for a channel
3936async fn join_channel_chat(
3937 request: proto::JoinChannelChat,
3938 response: Response<proto::JoinChannelChat>,
3939 session: Session,
3940) -> Result<()> {
3941 let channel_id = ChannelId::from_proto(request.channel_id);
3942
3943 let db = session.db().await;
3944 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3945 .await?;
3946 let messages = db
3947 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3948 .await?;
3949 response.send(proto::JoinChannelChatResponse {
3950 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3951 messages,
3952 })?;
3953 Ok(())
3954}
3955
3956/// Stop receiving chat updates for a channel
3957async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3958 let channel_id = ChannelId::from_proto(request.channel_id);
3959 session
3960 .db()
3961 .await
3962 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3963 .await?;
3964 Ok(())
3965}
3966
3967/// Retrieve the chat history for a channel
3968async fn get_channel_messages(
3969 request: proto::GetChannelMessages,
3970 response: Response<proto::GetChannelMessages>,
3971 session: Session,
3972) -> Result<()> {
3973 let channel_id = ChannelId::from_proto(request.channel_id);
3974 let messages = session
3975 .db()
3976 .await
3977 .get_channel_messages(
3978 channel_id,
3979 session.user_id(),
3980 MESSAGE_COUNT_PER_PAGE,
3981 Some(MessageId::from_proto(request.before_message_id)),
3982 )
3983 .await?;
3984 response.send(proto::GetChannelMessagesResponse {
3985 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3986 messages,
3987 })?;
3988 Ok(())
3989}
3990
3991/// Retrieve specific chat messages
3992async fn get_channel_messages_by_id(
3993 request: proto::GetChannelMessagesById,
3994 response: Response<proto::GetChannelMessagesById>,
3995 session: Session,
3996) -> Result<()> {
3997 let message_ids = request
3998 .message_ids
3999 .iter()
4000 .map(|id| MessageId::from_proto(*id))
4001 .collect::<Vec<_>>();
4002 let messages = session
4003 .db()
4004 .await
4005 .get_channel_messages_by_id(session.user_id(), &message_ids)
4006 .await?;
4007 response.send(proto::GetChannelMessagesResponse {
4008 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4009 messages,
4010 })?;
4011 Ok(())
4012}
4013
4014/// Retrieve the current users notifications
4015async fn get_notifications(
4016 request: proto::GetNotifications,
4017 response: Response<proto::GetNotifications>,
4018 session: Session,
4019) -> Result<()> {
4020 let notifications = session
4021 .db()
4022 .await
4023 .get_notifications(
4024 session.user_id(),
4025 NOTIFICATION_COUNT_PER_PAGE,
4026 request.before_id.map(db::NotificationId::from_proto),
4027 )
4028 .await?;
4029 response.send(proto::GetNotificationsResponse {
4030 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4031 notifications,
4032 })?;
4033 Ok(())
4034}
4035
4036/// Mark notifications as read
4037async fn mark_notification_as_read(
4038 request: proto::MarkNotificationRead,
4039 response: Response<proto::MarkNotificationRead>,
4040 session: Session,
4041) -> Result<()> {
4042 let database = &session.db().await;
4043 let notifications = database
4044 .mark_notification_as_read_by_id(
4045 session.user_id(),
4046 NotificationId::from_proto(request.notification_id),
4047 )
4048 .await?;
4049 send_notifications(
4050 &*session.connection_pool().await,
4051 &session.peer,
4052 notifications,
4053 );
4054 response.send(proto::Ack {})?;
4055 Ok(())
4056}
4057
4058/// Get the current users information
4059async fn get_private_user_info(
4060 _request: proto::GetPrivateUserInfo,
4061 response: Response<proto::GetPrivateUserInfo>,
4062 session: Session,
4063) -> Result<()> {
4064 let db = session.db().await;
4065
4066 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4067 let user = db
4068 .get_user_by_id(session.user_id())
4069 .await?
4070 .ok_or_else(|| anyhow!("user not found"))?;
4071 let flags = db.get_user_flags(session.user_id()).await?;
4072
4073 response.send(proto::GetPrivateUserInfoResponse {
4074 metrics_id,
4075 staff: user.admin,
4076 flags,
4077 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4078 })?;
4079 Ok(())
4080}
4081
4082/// Accept the terms of service (tos) on behalf of the current user
4083async fn accept_terms_of_service(
4084 _request: proto::AcceptTermsOfService,
4085 response: Response<proto::AcceptTermsOfService>,
4086 session: Session,
4087) -> Result<()> {
4088 let db = session.db().await;
4089
4090 let accepted_tos_at = Utc::now();
4091 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4092 .await?;
4093
4094 response.send(proto::AcceptTermsOfServiceResponse {
4095 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4096 })?;
4097 Ok(())
4098}
4099
/// The minimum age (30 days) an account must have reached in order to use the LLM
/// service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4102
4103async fn get_llm_api_token(
4104 _request: proto::GetLlmToken,
4105 response: Response<proto::GetLlmToken>,
4106 session: Session,
4107) -> Result<()> {
4108 let db = session.db().await;
4109
4110 let flags = db.get_user_flags(session.user_id()).await?;
4111 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4112
4113 if !session.is_staff() && !has_language_models_feature_flag {
4114 Err(anyhow!("permission denied"))?
4115 }
4116
4117 let user_id = session.user_id();
4118 let user = db
4119 .get_user_by_id(user_id)
4120 .await?
4121 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4122
4123 if user.accepted_tos_at.is_none() {
4124 Err(anyhow!("terms of service not accepted"))?
4125 }
4126
4127 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4128 let billing_preferences = db.get_billing_preferences(user.id).await?;
4129
4130 let token = LlmTokenClaims::create(
4131 &user,
4132 session.is_staff(),
4133 billing_preferences,
4134 &flags,
4135 has_llm_subscription,
4136 session.current_plan(&db).await?,
4137 session.system_id.clone(),
4138 &session.app_state.config,
4139 )?;
4140 response.send(proto::GetLlmTokenResponse { token })?;
4141 Ok(())
4142}
4143
4144fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4145 let message = match message {
4146 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4147 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4148 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4149 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4150 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4151 code: frame.code.into(),
4152 reason: frame.reason,
4153 })),
4154 // We should never receive a frame while reading the message, according
4155 // to the `tungstenite` maintainers:
4156 //
4157 // > It cannot occur when you read messages from the WebSocket, but it
4158 // > can be used when you want to send the raw frames (e.g. you want to
4159 // > send the frames to the WebSocket without composing the full message first).
4160 // >
4161 // > — https://github.com/snapview/tungstenite-rs/issues/268
4162 TungsteniteMessage::Frame(_) => {
4163 bail!("received an unexpected frame while reading the message")
4164 }
4165 };
4166
4167 Ok(message)
4168}
4169
4170fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4171 match message {
4172 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4173 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4174 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4175 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4176 AxumMessage::Close(frame) => {
4177 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4178 code: frame.code.into(),
4179 reason: frame.reason,
4180 }))
4181 }
4182 }
4183}
4184
4185fn notify_membership_updated(
4186 connection_pool: &mut ConnectionPool,
4187 result: MembershipUpdated,
4188 user_id: UserId,
4189 peer: &Peer,
4190) {
4191 for membership in &result.new_channels.channel_memberships {
4192 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4193 }
4194 for channel_id in &result.removed_channels {
4195 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4196 }
4197
4198 let user_channels_update = proto::UpdateUserChannels {
4199 channel_memberships: result
4200 .new_channels
4201 .channel_memberships
4202 .iter()
4203 .map(|cm| proto::ChannelMembership {
4204 channel_id: cm.channel_id.to_proto(),
4205 role: cm.role.into(),
4206 })
4207 .collect(),
4208 ..Default::default()
4209 };
4210
4211 let mut update = build_channels_update(result.new_channels);
4212 update.delete_channels = result
4213 .removed_channels
4214 .into_iter()
4215 .map(|id| id.to_proto())
4216 .collect();
4217 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4218
4219 for connection_id in connection_pool.user_connection_ids(user_id) {
4220 peer.send(connection_id, user_channels_update.clone())
4221 .trace_err();
4222 peer.send(connection_id, update.clone()).trace_err();
4223 }
4224}
4225
4226fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4227 proto::UpdateUserChannels {
4228 channel_memberships: channels
4229 .channel_memberships
4230 .iter()
4231 .map(|m| proto::ChannelMembership {
4232 channel_id: m.channel_id.to_proto(),
4233 role: m.role.into(),
4234 })
4235 .collect(),
4236 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4237 observed_channel_message_id: channels.observed_channel_messages.clone(),
4238 }
4239}
4240
4241fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4242 let mut update = proto::UpdateChannels::default();
4243
4244 for channel in channels.channels {
4245 update.channels.push(channel.to_proto());
4246 }
4247
4248 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4249 update.latest_channel_message_ids = channels.latest_channel_messages;
4250
4251 for (channel_id, participants) in channels.channel_participants {
4252 update
4253 .channel_participants
4254 .push(proto::ChannelParticipants {
4255 channel_id: channel_id.to_proto(),
4256 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4257 });
4258 }
4259
4260 for channel in channels.invited_channels {
4261 update.channel_invitations.push(channel.to_proto());
4262 }
4263
4264 update
4265}
4266
4267fn build_initial_contacts_update(
4268 contacts: Vec<db::Contact>,
4269 pool: &ConnectionPool,
4270) -> proto::UpdateContacts {
4271 let mut update = proto::UpdateContacts::default();
4272
4273 for contact in contacts {
4274 match contact {
4275 db::Contact::Accepted { user_id, busy } => {
4276 update.contacts.push(contact_for_user(user_id, busy, pool));
4277 }
4278 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4279 db::Contact::Incoming { user_id } => {
4280 update
4281 .incoming_requests
4282 .push(proto::IncomingContactRequest {
4283 requester_id: user_id.to_proto(),
4284 })
4285 }
4286 }
4287 }
4288
4289 update
4290}
4291
4292fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4293 proto::Contact {
4294 user_id: user_id.to_proto(),
4295 online: pool.is_user_online(user_id),
4296 busy,
4297 }
4298}
4299
4300fn room_updated(room: &proto::Room, peer: &Peer) {
4301 broadcast(
4302 None,
4303 room.participants
4304 .iter()
4305 .filter_map(|participant| Some(participant.peer_id?.into())),
4306 |peer_id| {
4307 peer.send(
4308 peer_id,
4309 proto::RoomUpdated {
4310 room: Some(room.clone()),
4311 },
4312 )
4313 },
4314 );
4315}
4316
4317fn channel_updated(
4318 channel: &db::channel::Model,
4319 room: &proto::Room,
4320 peer: &Peer,
4321 pool: &ConnectionPool,
4322) {
4323 let participants = room
4324 .participants
4325 .iter()
4326 .map(|p| p.user_id)
4327 .collect::<Vec<_>>();
4328
4329 broadcast(
4330 None,
4331 pool.channel_connection_ids(channel.root_id())
4332 .filter_map(|(channel_id, role)| {
4333 role.can_see_channel(channel.visibility)
4334 .then_some(channel_id)
4335 }),
4336 |peer_id| {
4337 peer.send(
4338 peer_id,
4339 proto::UpdateChannels {
4340 channel_participants: vec![proto::ChannelParticipants {
4341 channel_id: channel.id.to_proto(),
4342 participant_user_ids: participants.clone(),
4343 }],
4344 ..Default::default()
4345 },
4346 )
4347 },
4348 );
4349}
4350
4351async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4352 let db = session.db().await;
4353
4354 let contacts = db.get_contacts(user_id).await?;
4355 let busy = db.is_user_busy(user_id).await?;
4356
4357 let pool = session.connection_pool().await;
4358 let updated_contact = contact_for_user(user_id, busy, &pool);
4359 for contact in contacts {
4360 if let db::Contact::Accepted {
4361 user_id: contact_user_id,
4362 ..
4363 } = contact
4364 {
4365 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4366 session
4367 .peer
4368 .send(
4369 contact_conn_id,
4370 proto::UpdateContacts {
4371 contacts: vec![updated_contact.clone()],
4372 remove_contacts: Default::default(),
4373 incoming_requests: Default::default(),
4374 remove_incoming_requests: Default::default(),
4375 outgoing_requests: Default::default(),
4376 remove_outgoing_requests: Default::default(),
4377 },
4378 )
4379 .trace_err();
4380 }
4381 }
4382 }
4383 Ok(())
4384}
4385
4386async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4387 let mut contacts_to_update = HashSet::default();
4388
4389 let room_id;
4390 let canceled_calls_to_user_ids;
4391 let livekit_room;
4392 let delete_livekit_room;
4393 let room;
4394 let channel;
4395
4396 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4397 contacts_to_update.insert(session.user_id());
4398
4399 for project in left_room.left_projects.values() {
4400 project_left(project, session);
4401 }
4402
4403 room_id = RoomId::from_proto(left_room.room.id);
4404 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4405 livekit_room = mem::take(&mut left_room.room.livekit_room);
4406 delete_livekit_room = left_room.deleted;
4407 room = mem::take(&mut left_room.room);
4408 channel = mem::take(&mut left_room.channel);
4409
4410 room_updated(&room, &session.peer);
4411 } else {
4412 return Ok(());
4413 }
4414
4415 if let Some(channel) = channel {
4416 channel_updated(
4417 &channel,
4418 &room,
4419 &session.peer,
4420 &*session.connection_pool().await,
4421 );
4422 }
4423
4424 {
4425 let pool = session.connection_pool().await;
4426 for canceled_user_id in canceled_calls_to_user_ids {
4427 for connection_id in pool.user_connection_ids(canceled_user_id) {
4428 session
4429 .peer
4430 .send(
4431 connection_id,
4432 proto::CallCanceled {
4433 room_id: room_id.to_proto(),
4434 },
4435 )
4436 .trace_err();
4437 }
4438 contacts_to_update.insert(canceled_user_id);
4439 }
4440 }
4441
4442 for contact_user_id in contacts_to_update {
4443 update_user_contacts(contact_user_id, session).await?;
4444 }
4445
4446 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4447 live_kit
4448 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4449 .await
4450 .trace_err();
4451
4452 if delete_livekit_room {
4453 live_kit.delete_room(livekit_room).await.trace_err();
4454 }
4455 }
4456
4457 Ok(())
4458}
4459
4460async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4461 let left_channel_buffers = session
4462 .db()
4463 .await
4464 .leave_channel_buffers(session.connection_id)
4465 .await?;
4466
4467 for left_buffer in left_channel_buffers {
4468 channel_buffer_updated(
4469 session.connection_id,
4470 left_buffer.connections,
4471 &proto::UpdateChannelBufferCollaborators {
4472 channel_id: left_buffer.channel_id.to_proto(),
4473 collaborators: left_buffer.collaborators,
4474 },
4475 &session.peer,
4476 );
4477 }
4478
4479 Ok(())
4480}
4481
4482fn project_left(project: &db::LeftProject, session: &Session) {
4483 for connection_id in &project.connection_ids {
4484 if project.should_unshare {
4485 session
4486 .peer
4487 .send(
4488 *connection_id,
4489 proto::UnshareProject {
4490 project_id: project.id.to_proto(),
4491 },
4492 )
4493 .trace_err();
4494 } else {
4495 session
4496 .peer
4497 .send(
4498 *connection_id,
4499 proto::RemoveProjectCollaborator {
4500 project_id: project.id.to_proto(),
4501 peer_id: Some(session.connection_id.into()),
4502 },
4503 )
4504 .trace_err();
4505 }
4506 }
4507}
4508
4509pub trait ResultExt {
4510 type Ok;
4511
4512 fn trace_err(self) -> Option<Self::Ok>;
4513}
4514
4515impl<T, E> ResultExt for Result<T, E>
4516where
4517 E: std::fmt::Debug,
4518{
4519 type Ok = T;
4520
4521 #[track_caller]
4522 fn trace_err(self) -> Option<T> {
4523 match self {
4524 Ok(value) => Some(value),
4525 Err(error) => {
4526 tracing::error!("{:?}", error);
4527 None
4528 }
4529 }
4530 }
4531}