1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
10 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
11 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use chrono::Utc;
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use http_client::HttpClient;
38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
39use reqwest_client::ReqwestClient;
40use rpc::proto::split_repository_update;
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
46 TryStreamExt,
47};
48use prometheus::{register_int_gauge, IntGauge};
49use rpc::{
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 atomic::{AtomicBool, Ordering::SeqCst},
68 Arc, OnceLock,
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{watch, MutexGuard, Semaphore};
74use tower::ServiceBuilder;
75use tracing::{
76 field::{self},
77 info_span, instrument, Instrument,
78};
79
/// Reconnection timeout: 30 seconds.
///
/// NOTE(review): the consumers of this constant are outside this part of the file —
/// presumably the window a disconnected client has to reconnect; confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] sleeps before it cleans up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not referenced in this part of the file; judging by their
// names they bound channel-message paging, message length, and notification paging — confirm
// against the handlers that use them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
/// A type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the [`Session`] of the connection the message
/// arrived on, and returns a boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
/// Handle given to request handlers for replying to a request of type `R`.
struct Response<R> {
    /// Peer through which the reply is delivered.
    peer: Arc<Peer>,
    /// Receipt identifying the request that is being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`; `Server::add_request_handler` reads it to detect handlers that
    /// returned `Ok` without replying.
    responded: Arc<AtomicBool>,
}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
/// The identity a connection acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting on behalf of `user`: `Session::user_id` and `Session::email` resolve to
    /// `user`, while `admin`'s login is recorded as the impersonator in spans and debug output.
    Impersonated { user: User, admin: User },
}
111
112impl Principal {
113 fn update_span(&self, span: &tracing::Span) {
114 match &self {
115 Principal::User(user) => {
116 span.record("user_id", user.id.0);
117 span.record("login", &user.github_login);
118 }
119 Principal::Impersonated { user, admin } => {
120 span.record("user_id", user.id.0);
121 span.record("login", &user.github_login);
122 span.record("impersonator", &admin.github_login);
123 }
124 }
125 }
126}
127
/// Per-connection state that is cloned and handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, shared by all sessions.
    peer: Arc<Peer>,
    /// The server's connection pool; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin client; `None` when no admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was established, if any.
    system_id: Option<String>,
    /// Executor the connection was started with (currently only stored).
    _executor: Executor,
}
144
145impl Session {
146 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
147 #[cfg(test)]
148 tokio::task::yield_now().await;
149 let guard = self.db.lock().await;
150 #[cfg(test)]
151 tokio::task::yield_now().await;
152 guard
153 }
154
155 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 let guard = self.connection_pool.lock();
159 ConnectionPoolGuard {
160 guard,
161 _not_send: PhantomData,
162 }
163 }
164
165 fn is_staff(&self) -> bool {
166 match &self.principal {
167 Principal::User(user) => user.admin,
168 Principal::Impersonated { .. } => true,
169 }
170 }
171
172 pub async fn has_llm_subscription(
173 &self,
174 db: &MutexGuard<'_, DbHandle>,
175 ) -> anyhow::Result<bool> {
176 if self.is_staff() {
177 return Ok(true);
178 }
179
180 let user_id = self.user_id();
181
182 Ok(db.has_active_billing_subscription(user_id).await?)
183 }
184
185 pub async fn current_plan(
186 &self,
187 _db: &MutexGuard<'_, DbHandle>,
188 ) -> anyhow::Result<proto::Plan> {
189 if self.is_staff() {
190 Ok(proto::Plan::ZedPro)
191 } else {
192 Ok(proto::Plan::Free)
193 }
194 }
195
196 fn user_id(&self) -> UserId {
197 match &self.principal {
198 Principal::User(user) => user.id,
199 Principal::Impersonated { user, .. } => user.id,
200 }
201 }
202
203 pub fn email(&self) -> Option<String> {
204 match &self.principal {
205 Principal::User(user) => user.email_address.clone(),
206 Principal::Impersonated { user, .. } => user.email_address.clone(),
207 }
208 }
209}
210
211impl Debug for Session {
212 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
213 let mut result = f.debug_struct("Session");
214 match &self.principal {
215 Principal::User(user) => {
216 result.field("user", &user.github_login);
217 }
218 Principal::Impersonated { user, admin } => {
219 result.field("user", &user.github_login);
220 result.field("impersonator", &admin.github_login);
221 }
222 }
223 result.field("connection_id", &self.connection_id).finish()
224 }
225}
226
/// Newtype around the shared [`Database`]; it dereferences to `Database` and is what
/// `Session::db` hands out from behind the session's mutex.
struct DbHandle(Arc<Database>);
228
229impl Deref for DbHandle {
230 type Target = Database;
231
232 fn deref(&self) -> &Self::Target {
233 self.0.as_ref()
234 }
235}
236
/// The RPC server: owns the peer, the pool of live connections, and the table of message
/// handlers.
pub struct Server {
    /// Current server id; behind a mutex because test builds replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that manages the connections and sends messages over them.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is torn down; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
245
/// A locked view of the [`ConnectionPool`] that dereferences to the pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying lock guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Marker field: `Rc` is not `Send`, so this keeps the guard type from being `Send`.
    _not_send: PhantomData<Rc<()>>,
}
250
/// A serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`. The connection pool stays locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer.
    peer: &'a Peer,
    /// The locked connection pool; serialized through its `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
257
258pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
259where
260 S: Serializer,
261 T: Deref<Target = U>,
262 U: Serialize,
263{
264 Serialize::serialize(value.deref(), serializer)
265}
266
267impl Server {
    /// Creates a server with the given id and registers a handler for every supported message
    /// type.
    ///
    /// The peer is created from the numeric server id, the connection pool and handler table
    /// start out empty, and the teardown channel starts at `false`.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`),
    /// so every message type may appear at most once in the list below.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // `add_request_handler` is used for messages that expect a response,
        // `add_message_handler` for messages that do not.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_find_search_candidates_request)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The remaining handlers that need configuration are wrapped in closures that
            // capture their own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler(get_cached_embeddings)
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
453
    /// Schedules the cleanup of resources left behind by stale servers and returns immediately.
    ///
    /// A detached task is spawned on the application executor. After sleeping for
    /// `CLEANUP_TIMEOUT` it:
    ///
    /// 1. asks the database for the stale room and channel ids of this environment,
    /// 2. clears stale channel-buffer collaborators and notifies the affected connections,
    /// 3. clears stale room participants, broadcasts the refreshed room (and channel), cancels
    ///    pending calls, pushes updated contact entries, and deletes the LiveKit room when no
    ///    participants remain,
    /// 4. deletes the stale server records.
    ///
    /// Every failure inside the task goes through `trace_err` and the task carries on; this
    /// method itself always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Tell everyone still connected to a stale channel buffer who its
                    // collaborators are now.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected while the refreshed room is available and consumed below;
                        // they keep their empty defaults when refreshing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so that it is released
                        // before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of every affected user to all of
                        // that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only rooms that ended up without participants are deleted.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale resource ids could not be retrieved.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
599
600 pub fn teardown(&self) {
601 self.peer.teardown();
602 self.connection_pool.lock().reset();
603 let _ = self.teardown.send(true);
604 }
605
606 #[cfg(test)]
607 pub fn reset(&self, id: ServerId) {
608 self.teardown();
609 *self.id.lock() = id;
610 self.peer.reset(id.0 as u32);
611 let _ = self.teardown.send(false);
612 }
613
614 #[cfg(test)]
615 pub fn id(&self) -> ServerId {
616 *self.id.lock()
617 }
618
619 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
620 where
621 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
622 Fut: 'static + Send + Future<Output = Result<()>>,
623 M: EnvelopedMessage,
624 {
625 let prev_handler = self.handlers.insert(
626 TypeId::of::<M>(),
627 Box::new(move |envelope, session| {
628 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
629 let received_at = envelope.received_at;
630 tracing::info!("message received");
631 let start_time = Instant::now();
632 let future = (handler)(*envelope, session);
633 async move {
634 let result = future.await;
635 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
636 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
637 let queue_duration_ms = total_duration_ms - processing_duration_ms;
638 let payload_type = M::NAME;
639
640 match result {
641 Err(error) => {
642 tracing::error!(
643 ?error,
644 total_duration_ms,
645 processing_duration_ms,
646 queue_duration_ms,
647 payload_type,
648 "error handling message"
649 )
650 }
651 Ok(()) => tracing::info!(
652 total_duration_ms,
653 processing_duration_ms,
654 queue_duration_ms,
655 "finished handling message"
656 ),
657 }
658 }
659 .boxed()
660 }),
661 );
662 if prev_handler.is_some() {
663 panic!("registered a handler for the same message twice");
664 }
665 self
666 }
667
668 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
669 where
670 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
671 Fut: 'static + Send + Future<Output = Result<()>>,
672 M: EnvelopedMessage,
673 {
674 self.add_handler(move |envelope, session| handler(envelope.payload, session));
675 self
676 }
677
678 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
679 where
680 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
681 Fut: Send + Future<Output = Result<()>>,
682 M: RequestMessage,
683 {
684 let handler = Arc::new(handler);
685 self.add_handler(move |envelope, session| {
686 let receipt = envelope.receipt();
687 let handler = handler.clone();
688 async move {
689 let peer = session.peer.clone();
690 let responded = Arc::new(AtomicBool::default());
691 let response = Response {
692 peer: peer.clone(),
693 responded: responded.clone(),
694 receipt,
695 };
696 match (handler)(envelope.payload, response, session).await {
697 Ok(()) => {
698 if responded.load(std::sync::atomic::Ordering::SeqCst) {
699 Ok(())
700 } else {
701 Err(anyhow!("handler did not send a response"))?
702 }
703 }
704 Err(error) => {
705 let proto_err = match &error {
706 Error::Internal(err) => err.to_proto(),
707 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
708 };
709 peer.respond_with_error(receipt, proto_err)?;
710 Err(error)
711 }
712 }
713 }
714 })
715 }
716
    /// Builds the future that drives a single client connection from start to finish.
    ///
    /// Synchronously, a tracing span is created and filled in with the principal and the
    /// optional GeoIP country code, and the teardown channel is subscribed to. The returned
    /// future then:
    ///
    /// 1. gives up immediately if the server is already tearing down;
    /// 2. registers the connection with the peer and builds the `Session` (including an HTTP
    ///    client and, if an admin API key is configured, a Supermaven client);
    /// 3. sends the initial client update, giving up if that fails;
    /// 4. dispatches incoming messages to the registered handlers until the connection's I/O
    ///    finishes, the message stream ends, or the server is torn down;
    /// 5. unless the server was torn down, drops the pending foreground handlers and runs
    ///    `connection_lost` for the session.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing
            // further messages until their respective request completed, neither would get a
            // chance to respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each in-flight handler holds one permit, so at most 256 handlers run concurrently
            // for this connection; no further message is read while all permits are taken.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written: teardown first,
                // then connection I/O, then running handlers, and only then new messages.
                futures::select_biased! {
                    // Returning here skips the `connection_lost` call below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
865
866 async fn send_initial_client_update(
867 &self,
868 connection_id: ConnectionId,
869 principal: &Principal,
870 zed_version: ZedVersion,
871 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
872 session: &Session,
873 ) -> Result<()> {
874 self.peer.send(
875 connection_id,
876 proto::Hello {
877 peer_id: Some(connection_id.into()),
878 },
879 )?;
880 tracing::info!("sent hello message");
881 if let Some(send_connection_id) = send_connection_id.take() {
882 let _ = send_connection_id.send(connection_id);
883 }
884
885 match principal {
886 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
887 if !user.connected_once {
888 self.peer.send(connection_id, proto::ShowContacts {})?;
889 self.app_state
890 .db
891 .set_user_connected_once(user.id, true)
892 .await?;
893 }
894
895 update_user_plan(user.id, session).await?;
896
897 let contacts = self.app_state.db.get_contacts(user.id).await?;
898
899 {
900 let mut pool = self.connection_pool.lock();
901 pool.add_connection(connection_id, user.id, user.admin, zed_version);
902 self.peer.send(
903 connection_id,
904 build_initial_contacts_update(contacts, &pool),
905 )?;
906 }
907
908 if should_auto_subscribe_to_channels(zed_version) {
909 subscribe_user_to_channels(user.id, session).await?;
910 }
911
912 if let Some(incoming_call) =
913 self.app_state.db.incoming_call_for_user(user.id).await?
914 {
915 self.peer.send(connection_id, incoming_call)?;
916 }
917
918 update_user_contacts(user.id, session).await?;
919 }
920 }
921
922 Ok(())
923 }
924
925 pub async fn invite_code_redeemed(
926 self: &Arc<Self>,
927 inviter_id: UserId,
928 invitee_id: UserId,
929 ) -> Result<()> {
930 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
931 if let Some(code) = &user.invite_code {
932 let pool = self.connection_pool.lock();
933 let invitee_contact = contact_for_user(invitee_id, false, &pool);
934 for connection_id in pool.user_connection_ids(inviter_id) {
935 self.peer.send(
936 connection_id,
937 proto::UpdateContacts {
938 contacts: vec![invitee_contact.clone()],
939 ..Default::default()
940 },
941 )?;
942 self.peer.send(
943 connection_id,
944 proto::UpdateInviteInfo {
945 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
946 count: user.invite_count as u32,
947 },
948 )?;
949 }
950 }
951 }
952 Ok(())
953 }
954
955 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
956 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
957 if let Some(invite_code) = &user.invite_code {
958 let pool = self.connection_pool.lock();
959 for connection_id in pool.user_connection_ids(user_id) {
960 self.peer.send(
961 connection_id,
962 proto::UpdateInviteInfo {
963 url: format!(
964 "{}{}",
965 self.app_state.config.invite_link_prefix, invite_code
966 ),
967 count: user.invite_count as u32,
968 },
969 )?;
970 }
971 }
972 }
973 Ok(())
974 }
975
976 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
977 let pool = self.connection_pool.lock();
978 for connection_id in pool.user_connection_ids(user_id) {
979 self.peer
980 .send(connection_id, proto::RefreshLlmToken {})
981 .trace_err();
982 }
983 }
984
985 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
986 ServerSnapshot {
987 connection_pool: ConnectionPoolGuard {
988 guard: self.connection_pool.lock(),
989 _not_send: PhantomData,
990 },
991 peer: &self.peer,
992 }
993 }
994}
995
996impl Deref for ConnectionPoolGuard<'_> {
997 type Target = ConnectionPool;
998
999 fn deref(&self) -> &Self::Target {
1000 &self.guard
1001 }
1002}
1003
1004impl DerefMut for ConnectionPoolGuard<'_> {
1005 fn deref_mut(&mut self) -> &mut Self::Target {
1006 &mut self.guard
1007 }
1008}
1009
/// Checks the pool's invariants when a guard is released in test builds.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // The call below is compiled only under `cfg(test)`; in other builds this body is
        // empty.
        #[cfg(test)]
        self.check_invariants();
    }
}
1016
1017fn broadcast<F>(
1018 sender_id: Option<ConnectionId>,
1019 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1020 mut f: F,
1021) where
1022 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1023{
1024 for receiver_id in receiver_ids {
1025 if Some(receiver_id) != sender_id {
1026 if let Err(error) = f(receiver_id) {
1027 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1028 }
1029 }
1030 }
1031}
1032
1033pub struct ProtocolVersion(u32);
1034
1035impl Header for ProtocolVersion {
1036 fn name() -> &'static HeaderName {
1037 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1039 }
1040
1041 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042 where
1043 Self: Sized,
1044 I: Iterator<Item = &'i axum::http::HeaderValue>,
1045 {
1046 let version = values
1047 .next()
1048 .ok_or_else(axum::headers::Error::invalid)?
1049 .to_str()
1050 .map_err(|_| axum::headers::Error::invalid())?
1051 .parse()
1052 .map_err(|_| axum::headers::Error::invalid())?;
1053 Ok(Self(version))
1054 }
1055
1056 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057 values.extend([self.0.to_string().parse().unwrap()]);
1058 }
1059}
1060
1061pub struct AppVersionHeader(SemanticVersion);
1062impl Header for AppVersionHeader {
1063 fn name() -> &'static HeaderName {
1064 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1065 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1066 }
1067
1068 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069 where
1070 Self: Sized,
1071 I: Iterator<Item = &'i axum::http::HeaderValue>,
1072 {
1073 let version = values
1074 .next()
1075 .ok_or_else(axum::headers::Error::invalid)?
1076 .to_str()
1077 .map_err(|_| axum::headers::Error::invalid())?
1078 .parse()
1079 .map_err(|_| axum::headers::Error::invalid())?;
1080 Ok(Self(version))
1081 }
1082
1083 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084 values.extend([self.0.to_string().parse().unwrap()]);
1085 }
1086}
1087
1088pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1089 Router::new()
1090 .route("/rpc", get(handle_websocket_request))
1091 .layer(
1092 ServiceBuilder::new()
1093 .layer(Extension(server.app_state.clone()))
1094 .layer(middleware::from_fn(auth::validate_header)),
1095 )
1096 .route("/metrics", get(handle_metrics))
1097 .layer(Extension(server))
1098}
1099
1100pub async fn handle_websocket_request(
1101 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1102 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1103 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1104 Extension(server): Extension<Arc<Server>>,
1105 Extension(principal): Extension<Principal>,
1106 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1107 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1108 ws: WebSocketUpgrade,
1109) -> axum::response::Response {
1110 if protocol_version != rpc::PROTOCOL_VERSION {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "client must be upgraded".to_string(),
1114 )
1115 .into_response();
1116 }
1117
1118 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1119 return (
1120 StatusCode::UPGRADE_REQUIRED,
1121 "no version header found".to_string(),
1122 )
1123 .into_response();
1124 };
1125
1126 if !version.can_collaborate() {
1127 return (
1128 StatusCode::UPGRADE_REQUIRED,
1129 "client must be upgraded".to_string(),
1130 )
1131 .into_response();
1132 }
1133
1134 let socket_address = socket_address.to_string();
1135 ws.on_upgrade(move |socket| {
1136 let socket = socket
1137 .map_ok(to_tungstenite_message)
1138 .err_into()
1139 .with(|message| async move { to_axum_message(message) });
1140 let connection = Connection::new(Box::pin(socket));
1141 async move {
1142 server
1143 .handle_connection(
1144 connection,
1145 socket_address,
1146 principal,
1147 version,
1148 country_code_header.map(|header| header.to_string()),
1149 system_id_header.map(|header| header.to_string()),
1150 None,
1151 Executor::Production,
1152 )
1153 .await;
1154 }
1155 })
1156}
1157
1158pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1159 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1160 let connections_metric = CONNECTIONS_METRIC
1161 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1162
1163 let connections = server
1164 .connection_pool
1165 .lock()
1166 .connections()
1167 .filter(|connection| !connection.admin)
1168 .count();
1169 connections_metric.set(connections as _);
1170
1171 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1172 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1173 register_int_gauge!(
1174 "shared_projects",
1175 "number of open projects with one or more guests"
1176 )
1177 .unwrap()
1178 });
1179
1180 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1181 shared_projects_metric.set(shared_projects as _);
1182
1183 let encoder = prometheus::TextEncoder::new();
1184 let metric_families = prometheus::gather();
1185 let encoded_metrics = encoder
1186 .encode_to_string(&metric_families)
1187 .map_err(|err| anyhow!("{}", err))?;
1188 Ok(encoded_metrics)
1189}
1190
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer, removes the connection from the pool and calls
/// `connection_lost` on the database (a failure there is handed to `trace_err` rather than
/// propagated). It then waits for whichever happens first:
///
/// * `RECONNECT_TIMEOUT` elapses: the session leaves its room and channel buffers, a pending
///   call is declined if the user has no connection left in the pool, and the user's contacts
///   are updated.
/// * `teardown` changes: nothing further is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout branch is
    // checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline on the user's behalf when none of their connections remain.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1237
1238/// Acknowledges a ping from a client, used to keep the connection alive.
1239async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1240 response.send(proto::Ack {})?;
1241 Ok(())
1242}
1243
1244/// Creates a new room for calling (outside of channels)
1245async fn create_room(
1246 _request: proto::CreateRoom,
1247 response: Response<proto::CreateRoom>,
1248 session: Session,
1249) -> Result<()> {
1250 let livekit_room = nanoid::nanoid!(30);
1251
1252 let live_kit_connection_info = util::maybe!(async {
1253 let live_kit = session.app_state.livekit_client.as_ref();
1254 let live_kit = live_kit?;
1255 let user_id = session.user_id().to_string();
1256
1257 let token = live_kit
1258 .room_token(&livekit_room, &user_id.to_string())
1259 .trace_err()?;
1260
1261 Some(proto::LiveKitConnectionInfo {
1262 server_url: live_kit.url().into(),
1263 token,
1264 can_publish: true,
1265 })
1266 })
1267 .await;
1268
1269 let room = session
1270 .db()
1271 .await
1272 .create_room(session.user_id(), session.connection_id, &livekit_room)
1273 .await?;
1274
1275 response.send(proto::CreateRoomResponse {
1276 room: Some(room.clone()),
1277 live_kit_connection_info,
1278 })?;
1279
1280 update_user_contacts(session.user_id(), &session).await?;
1281 Ok(())
1282}
1283
1284/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1285async fn join_room(
1286 request: proto::JoinRoom,
1287 response: Response<proto::JoinRoom>,
1288 session: Session,
1289) -> Result<()> {
1290 let room_id = RoomId::from_proto(request.id);
1291
1292 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1293
1294 if let Some(channel_id) = channel_id {
1295 return join_channel_internal(channel_id, Box::new(response), session).await;
1296 }
1297
1298 let joined_room = {
1299 let room = session
1300 .db()
1301 .await
1302 .join_room(room_id, session.user_id(), session.connection_id)
1303 .await?;
1304 room_updated(&room.room, &session.peer);
1305 room.into_inner()
1306 };
1307
1308 for connection_id in session
1309 .connection_pool()
1310 .await
1311 .user_connection_ids(session.user_id())
1312 {
1313 session
1314 .peer
1315 .send(
1316 connection_id,
1317 proto::CallCanceled {
1318 room_id: room_id.to_proto(),
1319 },
1320 )
1321 .trace_err();
1322 }
1323
1324 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1325 {
1326 live_kit
1327 .room_token(
1328 &joined_room.room.livekit_room,
1329 &session.user_id().to_string(),
1330 )
1331 .trace_err()
1332 .map(|token| proto::LiveKitConnectionInfo {
1333 server_url: live_kit.url().into(),
1334 token,
1335 can_publish: true,
1336 })
1337 } else {
1338 None
1339 };
1340
1341 response.send(proto::JoinRoomResponse {
1342 room: Some(joined_room.room),
1343 channel_id: None,
1344 live_kit_connection_info,
1345 })?;
1346
1347 update_user_contacts(session.user_id(), &session).await?;
1348 Ok(())
1349}
1350
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The response lists the room together with the projects that were reshared and rejoined.
/// Afterwards:
///
/// * collaborators of every reshared project are sent an `UpdateProjectCollaborator` mapping
///   the old connection id to this session's connection id, and an `UpdateProject` with the
///   project's worktrees is forwarded to them;
/// * `notify_rejoined_projects` is called for the rejoined projects;
/// * if the room has a channel, `channel_updated` is called;
/// * the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Declared outside the block so the values survive after `rejoined_room` is consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old peer id now maps to this connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1442
/// Brings a reconnected session and its collaborators back in sync for each rejoined project.
///
/// First, every collaborator of every project is sent an `UpdateProjectCollaborator` that maps
/// the project's old connection id to this session's connection id (send failures are handed
/// to `trace_err`). Then, for each project, the session's own connection is sent the worktree
/// updates, diagnostic summaries, worktree settings, repository updates and repository
/// removals recorded on the project.
///
/// The `worktrees`, `updated_repositories` and `removed_repositories` of each project are
/// taken with `mem::take`, so they are left empty afterwards.
///
/// # Errors
///
/// Returns the first error from sending to the session's own connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is sent as the sequence of messages produced by
            // `split_worktree_update`.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository updates go through `split_repository_update` before being sent.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1527
1528/// leave room disconnects from the room.
1529async fn leave_room(
1530 _: proto::LeaveRoom,
1531 response: Response<proto::LeaveRoom>,
1532 session: Session,
1533) -> Result<()> {
1534 leave_room_for_session(&session, session.connection_id).await?;
1535 response.send(proto::Ack {})?;
1536 Ok(())
1537}
1538
1539/// Updates the permissions of someone else in the room.
1540async fn set_room_participant_role(
1541 request: proto::SetRoomParticipantRole,
1542 response: Response<proto::SetRoomParticipantRole>,
1543 session: Session,
1544) -> Result<()> {
1545 let user_id = UserId::from_proto(request.user_id);
1546 let role = ChannelRole::from(request.role());
1547
1548 let (livekit_room, can_publish) = {
1549 let room = session
1550 .db()
1551 .await
1552 .set_room_participant_role(
1553 session.user_id(),
1554 RoomId::from_proto(request.room_id),
1555 user_id,
1556 role,
1557 )
1558 .await?;
1559
1560 let livekit_room = room.livekit_room.clone();
1561 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1562 room_updated(&room, &session.peer);
1563 (livekit_room, can_publish)
1564 };
1565
1566 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1567 live_kit
1568 .update_participant(
1569 livekit_room.clone(),
1570 request.user_id.to_string(),
1571 livekit_api::proto::ParticipantPermission {
1572 can_subscribe: true,
1573 can_publish,
1574 can_publish_data: can_publish,
1575 hidden: false,
1576 recorder: false,
1577 },
1578 )
1579 .await
1580 .trace_err();
1581 }
1582
1583 response.send(proto::Ack {})?;
1584 Ok(())
1585}
1586
/// Call someone else into the current room.
///
/// The called user must be a contact of the caller. The call is recorded in the database, the
/// room is broadcast, and the incoming-call message is sent as a request to every connection
/// of the called user. The first successful reply acknowledges the request. If no connection
/// replies successfully (including when the user has no connections), the call is marked as
/// failed and an error is returned.
///
/// # Errors
///
/// * "cannot call a user who isn't a contact" when `has_contact` is false.
/// * "failed to ring user" when no connection of the called user answered successfully.
/// * Any database or send error encountered along the way.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out of the borrowed tuple so it can outlive this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One request per connection of the called user; replies are handled in the order they
    // complete.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // The first successful reply is enough to acknowledge the caller.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Reached only when no request succeeded.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1655
1656/// Cancel an outgoing call.
1657async fn cancel_call(
1658 request: proto::CancelCall,
1659 response: Response<proto::CancelCall>,
1660 session: Session,
1661) -> Result<()> {
1662 let called_user_id = UserId::from_proto(request.called_user_id);
1663 let room_id = RoomId::from_proto(request.room_id);
1664 {
1665 let room = session
1666 .db()
1667 .await
1668 .cancel_call(room_id, session.connection_id, called_user_id)
1669 .await?;
1670 room_updated(&room, &session.peer);
1671 }
1672
1673 for connection_id in session
1674 .connection_pool()
1675 .await
1676 .user_connection_ids(called_user_id)
1677 {
1678 session
1679 .peer
1680 .send(
1681 connection_id,
1682 proto::CallCanceled {
1683 room_id: room_id.to_proto(),
1684 },
1685 )
1686 .trace_err();
1687 }
1688 response.send(proto::Ack {})?;
1689
1690 update_user_contacts(called_user_id, &session).await?;
1691 Ok(())
1692}
1693
1694/// Decline an incoming call.
1695async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1696 let room_id = RoomId::from_proto(message.room_id);
1697 {
1698 let room = session
1699 .db()
1700 .await
1701 .decline_call(Some(room_id), session.user_id())
1702 .await?
1703 .ok_or_else(|| anyhow!("failed to decline call"))?;
1704 room_updated(&room, &session.peer);
1705 }
1706
1707 for connection_id in session
1708 .connection_pool()
1709 .await
1710 .user_connection_ids(session.user_id())
1711 {
1712 session
1713 .peer
1714 .send(
1715 connection_id,
1716 proto::CallCanceled {
1717 room_id: room_id.to_proto(),
1718 },
1719 )
1720 .trace_err();
1721 }
1722 update_user_contacts(session.user_id(), &session).await?;
1723 Ok(())
1724}
1725
1726/// Updates other participants in the room with your current location.
1727async fn update_participant_location(
1728 request: proto::UpdateParticipantLocation,
1729 response: Response<proto::UpdateParticipantLocation>,
1730 session: Session,
1731) -> Result<()> {
1732 let room_id = RoomId::from_proto(request.room_id);
1733 let location = request
1734 .location
1735 .ok_or_else(|| anyhow!("invalid location"))?;
1736
1737 let db = session.db().await;
1738 let room = db
1739 .update_room_participant_location(room_id, session.connection_id, location)
1740 .await?;
1741
1742 room_updated(&room, &session.peer);
1743 response.send(proto::Ack {})?;
1744 Ok(())
1745}
1746
1747/// Share a project into the room.
1748async fn share_project(
1749 request: proto::ShareProject,
1750 response: Response<proto::ShareProject>,
1751 session: Session,
1752) -> Result<()> {
1753 let (project_id, room) = &*session
1754 .db()
1755 .await
1756 .share_project(
1757 RoomId::from_proto(request.room_id),
1758 session.connection_id,
1759 &request.worktrees,
1760 request.is_ssh_project,
1761 )
1762 .await?;
1763 response.send(proto::ShareProjectResponse {
1764 project_id: project_id.to_proto(),
1765 })?;
1766 room_updated(room, &session.peer);
1767
1768 Ok(())
1769}
1770
1771/// Unshare a project from the room.
1772async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1773 let project_id = ProjectId::from_proto(message.project_id);
1774 unshare_project_internal(project_id, session.connection_id, &session).await
1775}
1776
1777async fn unshare_project_internal(
1778 project_id: ProjectId,
1779 connection_id: ConnectionId,
1780 session: &Session,
1781) -> Result<()> {
1782 let delete = {
1783 let room_guard = session
1784 .db()
1785 .await
1786 .unshare_project(project_id, connection_id)
1787 .await?;
1788
1789 let (delete, room, guest_connection_ids) = &*room_guard;
1790
1791 let message = proto::UnshareProject {
1792 project_id: project_id.to_proto(),
1793 };
1794
1795 broadcast(
1796 Some(connection_id),
1797 guest_connection_ids.iter().copied(),
1798 |conn_id| session.peer.send(conn_id, message.clone()),
1799 );
1800 if let Some(room) = room {
1801 room_updated(room, &session.peer);
1802 }
1803
1804 *delete
1805 };
1806
1807 if delete {
1808 let db = session.db().await;
1809 db.delete_project(project_id).await?;
1810 }
1811
1812 Ok(())
1813}
1814
/// Join someone else's shared project.
///
/// Registers this session's connection and user as a collaborator through the database, then
/// delegates to `join_project_internal` to reply and stream the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // The database handle is released explicitly; `project` and `replica_id` stay usable
    // after this point.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1833
/// Abstraction over whatever delivers the `JoinProjectResponse` from
/// `join_project_internal`.
trait JoinProjectInternalResponse {
    /// Consumes the responder and delivers `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Written as a path call on `Response` rather than as `self.send(..)`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1842
1843fn join_project_internal(
1844 response: impl JoinProjectInternalResponse,
1845 session: Session,
1846 project: &mut Project,
1847 replica_id: &ReplicaId,
1848) -> Result<()> {
1849 let collaborators = project
1850 .collaborators
1851 .iter()
1852 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1853 .map(|collaborator| collaborator.to_proto())
1854 .collect::<Vec<_>>();
1855 let project_id = project.id;
1856 let guest_user_id = session.user_id();
1857
1858 let worktrees = project
1859 .worktrees
1860 .iter()
1861 .map(|(id, worktree)| proto::WorktreeMetadata {
1862 id: *id,
1863 root_name: worktree.root_name.clone(),
1864 visible: worktree.visible,
1865 abs_path: worktree.abs_path.clone(),
1866 })
1867 .collect::<Vec<_>>();
1868
1869 let add_project_collaborator = proto::AddProjectCollaborator {
1870 project_id: project_id.to_proto(),
1871 collaborator: Some(proto::Collaborator {
1872 peer_id: Some(session.connection_id.into()),
1873 replica_id: replica_id.0 as u32,
1874 user_id: guest_user_id.to_proto(),
1875 is_host: false,
1876 }),
1877 };
1878
1879 for collaborator in &collaborators {
1880 session
1881 .peer
1882 .send(
1883 collaborator.peer_id.unwrap().into(),
1884 add_project_collaborator.clone(),
1885 )
1886 .trace_err();
1887 }
1888
1889 // First, we send the metadata associated with each worktree.
1890 response.send(proto::JoinProjectResponse {
1891 project_id: project.id.0 as u64,
1892 worktrees: worktrees.clone(),
1893 replica_id: replica_id.0 as u32,
1894 collaborators: collaborators.clone(),
1895 language_servers: project.language_servers.clone(),
1896 role: project.role.into(),
1897 })?;
1898
1899 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1900 // Stream this worktree's entries.
1901 let message = proto::UpdateWorktree {
1902 project_id: project_id.to_proto(),
1903 worktree_id,
1904 abs_path: worktree.abs_path.clone(),
1905 root_name: worktree.root_name,
1906 updated_entries: worktree.entries,
1907 removed_entries: Default::default(),
1908 scan_id: worktree.scan_id,
1909 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1910 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1911 removed_repositories: Default::default(),
1912 };
1913 for update in proto::split_worktree_update(message) {
1914 session.peer.send(session.connection_id, update.clone())?;
1915 }
1916
1917 // Stream this worktree's diagnostics.
1918 for summary in worktree.diagnostic_summaries {
1919 session.peer.send(
1920 session.connection_id,
1921 proto::UpdateDiagnosticSummary {
1922 project_id: project_id.to_proto(),
1923 worktree_id: worktree.id,
1924 summary: Some(summary),
1925 },
1926 )?;
1927 }
1928
1929 for settings_file in worktree.settings_files {
1930 session.peer.send(
1931 session.connection_id,
1932 proto::UpdateWorktreeSettings {
1933 project_id: project_id.to_proto(),
1934 worktree_id: worktree.id,
1935 path: settings_file.path,
1936 content: Some(settings_file.content),
1937 kind: Some(settings_file.kind.to_proto() as i32),
1938 },
1939 )?;
1940 }
1941 }
1942
1943 for repository in mem::take(&mut project.repositories) {
1944 for update in split_repository_update(repository) {
1945 session.peer.send(session.connection_id, update)?;
1946 }
1947 }
1948
1949 for language_server in &project.language_servers {
1950 session.peer.send(
1951 session.connection_id,
1952 proto::UpdateLanguageServer {
1953 project_id: project_id.to_proto(),
1954 language_server_id: language_server.id,
1955 variant: Some(
1956 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1957 proto::LspDiskBasedDiagnosticsUpdated {},
1958 ),
1959 ),
1960 },
1961 )?;
1962 }
1963
1964 Ok(())
1965}
1966
1967/// Leave someone elses shared project.
1968async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1969 let sender_id = session.connection_id;
1970 let project_id = ProjectId::from_proto(request.project_id);
1971 let db = session.db().await;
1972
1973 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1974 tracing::info!(
1975 %project_id,
1976 "leave project"
1977 );
1978
1979 project_left(project, &session);
1980 if let Some(room) = room {
1981 room_updated(room, &session.peer);
1982 }
1983
1984 Ok(())
1985}
1986
1987/// Updates other participants with changes to the project
1988async fn update_project(
1989 request: proto::UpdateProject,
1990 response: Response<proto::UpdateProject>,
1991 session: Session,
1992) -> Result<()> {
1993 let project_id = ProjectId::from_proto(request.project_id);
1994 let (room, guest_connection_ids) = &*session
1995 .db()
1996 .await
1997 .update_project(project_id, session.connection_id, &request.worktrees)
1998 .await?;
1999 broadcast(
2000 Some(session.connection_id),
2001 guest_connection_ids.iter().copied(),
2002 |connection_id| {
2003 session
2004 .peer
2005 .forward_send(session.connection_id, connection_id, request.clone())
2006 },
2007 );
2008 if let Some(room) = room {
2009 room_updated(room, &session.peer);
2010 }
2011 response.send(proto::Ack {})?;
2012
2013 Ok(())
2014}
2015
2016/// Updates other participants with changes to the worktree
2017async fn update_worktree(
2018 request: proto::UpdateWorktree,
2019 response: Response<proto::UpdateWorktree>,
2020 session: Session,
2021) -> Result<()> {
2022 let guest_connection_ids = session
2023 .db()
2024 .await
2025 .update_worktree(&request, session.connection_id)
2026 .await?;
2027
2028 broadcast(
2029 Some(session.connection_id),
2030 guest_connection_ids.iter().copied(),
2031 |connection_id| {
2032 session
2033 .peer
2034 .forward_send(session.connection_id, connection_id, request.clone())
2035 },
2036 );
2037 response.send(proto::Ack {})?;
2038 Ok(())
2039}
2040
2041async fn update_repository(
2042 request: proto::UpdateRepository,
2043 response: Response<proto::UpdateRepository>,
2044 session: Session,
2045) -> Result<()> {
2046 let guest_connection_ids = session
2047 .db()
2048 .await
2049 .update_repository(&request, session.connection_id)
2050 .await?;
2051
2052 broadcast(
2053 Some(session.connection_id),
2054 guest_connection_ids.iter().copied(),
2055 |connection_id| {
2056 session
2057 .peer
2058 .forward_send(session.connection_id, connection_id, request.clone())
2059 },
2060 );
2061 response.send(proto::Ack {})?;
2062 Ok(())
2063}
2064
2065async fn remove_repository(
2066 request: proto::RemoveRepository,
2067 response: Response<proto::RemoveRepository>,
2068 session: Session,
2069) -> Result<()> {
2070 let guest_connection_ids = session
2071 .db()
2072 .await
2073 .remove_repository(&request, session.connection_id)
2074 .await?;
2075
2076 broadcast(
2077 Some(session.connection_id),
2078 guest_connection_ids.iter().copied(),
2079 |connection_id| {
2080 session
2081 .peer
2082 .forward_send(session.connection_id, connection_id, request.clone())
2083 },
2084 );
2085 response.send(proto::Ack {})?;
2086 Ok(())
2087}
2088
2089/// Updates other participants with changes to the diagnostics
2090async fn update_diagnostic_summary(
2091 message: proto::UpdateDiagnosticSummary,
2092 session: Session,
2093) -> Result<()> {
2094 let guest_connection_ids = session
2095 .db()
2096 .await
2097 .update_diagnostic_summary(&message, session.connection_id)
2098 .await?;
2099
2100 broadcast(
2101 Some(session.connection_id),
2102 guest_connection_ids.iter().copied(),
2103 |connection_id| {
2104 session
2105 .peer
2106 .forward_send(session.connection_id, connection_id, message.clone())
2107 },
2108 );
2109
2110 Ok(())
2111}
2112
2113/// Updates other participants with changes to the worktree settings
2114async fn update_worktree_settings(
2115 message: proto::UpdateWorktreeSettings,
2116 session: Session,
2117) -> Result<()> {
2118 let guest_connection_ids = session
2119 .db()
2120 .await
2121 .update_worktree_settings(&message, session.connection_id)
2122 .await?;
2123
2124 broadcast(
2125 Some(session.connection_id),
2126 guest_connection_ids.iter().copied(),
2127 |connection_id| {
2128 session
2129 .peer
2130 .forward_send(session.connection_id, connection_id, message.clone())
2131 },
2132 );
2133
2134 Ok(())
2135}
2136
2137/// Notify other participants that a language server has started.
2138async fn start_language_server(
2139 request: proto::StartLanguageServer,
2140 session: Session,
2141) -> Result<()> {
2142 let guest_connection_ids = session
2143 .db()
2144 .await
2145 .start_language_server(&request, session.connection_id)
2146 .await?;
2147
2148 broadcast(
2149 Some(session.connection_id),
2150 guest_connection_ids.iter().copied(),
2151 |connection_id| {
2152 session
2153 .peer
2154 .forward_send(session.connection_id, connection_id, request.clone())
2155 },
2156 );
2157 Ok(())
2158}
2159
2160/// Notify other participants that a language server has changed.
2161async fn update_language_server(
2162 request: proto::UpdateLanguageServer,
2163 session: Session,
2164) -> Result<()> {
2165 let project_id = ProjectId::from_proto(request.project_id);
2166 let project_connection_ids = session
2167 .db()
2168 .await
2169 .project_connection_ids(project_id, session.connection_id, true)
2170 .await?;
2171 broadcast(
2172 Some(session.connection_id),
2173 project_connection_ids.iter().copied(),
2174 |connection_id| {
2175 session
2176 .peer
2177 .forward_send(session.connection_id, connection_id, request.clone())
2178 },
2179 );
2180 Ok(())
2181}
2182
2183/// forward a project request to the host. These requests should be read only
2184/// as guests are allowed to send them.
2185async fn forward_read_only_project_request<T>(
2186 request: T,
2187 response: Response<T>,
2188 session: Session,
2189) -> Result<()>
2190where
2191 T: EntityMessage + RequestMessage,
2192{
2193 let project_id = ProjectId::from_proto(request.remote_entity_id());
2194 let host_connection_id = session
2195 .db()
2196 .await
2197 .host_for_read_only_project_request(project_id, session.connection_id)
2198 .await?;
2199 let payload = session
2200 .peer
2201 .forward_request(session.connection_id, host_connection_id, request)
2202 .await?;
2203 response.send(payload)?;
2204 Ok(())
2205}
2206
2207async fn forward_find_search_candidates_request(
2208 request: proto::FindSearchCandidates,
2209 response: Response<proto::FindSearchCandidates>,
2210 session: Session,
2211) -> Result<()> {
2212 let project_id = ProjectId::from_proto(request.remote_entity_id());
2213 let host_connection_id = session
2214 .db()
2215 .await
2216 .host_for_read_only_project_request(project_id, session.connection_id)
2217 .await?;
2218 let payload = session
2219 .peer
2220 .forward_request(session.connection_id, host_connection_id, request)
2221 .await?;
2222 response.send(payload)?;
2223 Ok(())
2224}
2225
2226/// forward a project request to the host. These requests are disallowed
2227/// for guests.
2228async fn forward_mutating_project_request<T>(
2229 request: T,
2230 response: Response<T>,
2231 session: Session,
2232) -> Result<()>
2233where
2234 T: EntityMessage + RequestMessage,
2235{
2236 let project_id = ProjectId::from_proto(request.remote_entity_id());
2237
2238 let host_connection_id = session
2239 .db()
2240 .await
2241 .host_for_mutating_project_request(project_id, session.connection_id)
2242 .await?;
2243 let payload = session
2244 .peer
2245 .forward_request(session.connection_id, host_connection_id, request)
2246 .await?;
2247 response.send(payload)?;
2248 Ok(())
2249}
2250
2251/// Notify other participants that a new buffer has been created
2252async fn create_buffer_for_peer(
2253 request: proto::CreateBufferForPeer,
2254 session: Session,
2255) -> Result<()> {
2256 session
2257 .db()
2258 .await
2259 .check_user_is_project_host(
2260 ProjectId::from_proto(request.project_id),
2261 session.connection_id,
2262 )
2263 .await?;
2264 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2265 session
2266 .peer
2267 .forward_send(session.connection_id, peer_id.into(), request)?;
2268 Ok(())
2269}
2270
2271/// Notify other participants that a buffer has been updated. This is
2272/// allowed for guests as long as the update is limited to selections.
2273async fn update_buffer(
2274 request: proto::UpdateBuffer,
2275 response: Response<proto::UpdateBuffer>,
2276 session: Session,
2277) -> Result<()> {
2278 let project_id = ProjectId::from_proto(request.project_id);
2279 let mut capability = Capability::ReadOnly;
2280
2281 for op in request.operations.iter() {
2282 match op.variant {
2283 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2284 Some(_) => capability = Capability::ReadWrite,
2285 }
2286 }
2287
2288 let host = {
2289 let guard = session
2290 .db()
2291 .await
2292 .connections_for_buffer_update(project_id, session.connection_id, capability)
2293 .await?;
2294
2295 let (host, guests) = &*guard;
2296
2297 broadcast(
2298 Some(session.connection_id),
2299 guests.clone(),
2300 |connection_id| {
2301 session
2302 .peer
2303 .forward_send(session.connection_id, connection_id, request.clone())
2304 },
2305 );
2306
2307 *host
2308 };
2309
2310 if host != session.connection_id {
2311 session
2312 .peer
2313 .forward_request(session.connection_id, host, request.clone())
2314 .await?;
2315 }
2316
2317 response.send(proto::Ack {})?;
2318 Ok(())
2319}
2320
2321async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2322 let project_id = ProjectId::from_proto(message.project_id);
2323
2324 let operation = message.operation.as_ref().context("invalid operation")?;
2325 let capability = match operation.variant.as_ref() {
2326 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2327 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2328 match buffer_op.variant {
2329 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2330 Capability::ReadOnly
2331 }
2332 _ => Capability::ReadWrite,
2333 }
2334 } else {
2335 Capability::ReadWrite
2336 }
2337 }
2338 Some(_) => Capability::ReadWrite,
2339 None => Capability::ReadOnly,
2340 };
2341
2342 let guard = session
2343 .db()
2344 .await
2345 .connections_for_buffer_update(project_id, session.connection_id, capability)
2346 .await?;
2347
2348 let (host, guests) = &*guard;
2349
2350 broadcast(
2351 Some(session.connection_id),
2352 guests.iter().chain([host]).copied(),
2353 |connection_id| {
2354 session
2355 .peer
2356 .forward_send(session.connection_id, connection_id, message.clone())
2357 },
2358 );
2359
2360 Ok(())
2361}
2362
2363/// Notify other participants that a project has been updated.
2364async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2365 request: T,
2366 session: Session,
2367) -> Result<()> {
2368 let project_id = ProjectId::from_proto(request.remote_entity_id());
2369 let project_connection_ids = session
2370 .db()
2371 .await
2372 .project_connection_ids(project_id, session.connection_id, false)
2373 .await?;
2374
2375 broadcast(
2376 Some(session.connection_id),
2377 project_connection_ids.iter().copied(),
2378 |connection_id| {
2379 session
2380 .peer
2381 .forward_send(session.connection_id, connection_id, request.clone())
2382 },
2383 );
2384 Ok(())
2385}
2386
2387/// Start following another user in a call.
2388async fn follow(
2389 request: proto::Follow,
2390 response: Response<proto::Follow>,
2391 session: Session,
2392) -> Result<()> {
2393 let room_id = RoomId::from_proto(request.room_id);
2394 let project_id = request.project_id.map(ProjectId::from_proto);
2395 let leader_id = request
2396 .leader_id
2397 .ok_or_else(|| anyhow!("invalid leader id"))?
2398 .into();
2399 let follower_id = session.connection_id;
2400
2401 session
2402 .db()
2403 .await
2404 .check_room_participants(room_id, leader_id, session.connection_id)
2405 .await?;
2406
2407 let response_payload = session
2408 .peer
2409 .forward_request(session.connection_id, leader_id, request)
2410 .await?;
2411 response.send(response_payload)?;
2412
2413 if let Some(project_id) = project_id {
2414 let room = session
2415 .db()
2416 .await
2417 .follow(room_id, project_id, leader_id, follower_id)
2418 .await?;
2419 room_updated(&room, &session.peer);
2420 }
2421
2422 Ok(())
2423}
2424
2425/// Stop following another user in a call.
2426async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2427 let room_id = RoomId::from_proto(request.room_id);
2428 let project_id = request.project_id.map(ProjectId::from_proto);
2429 let leader_id = request
2430 .leader_id
2431 .ok_or_else(|| anyhow!("invalid leader id"))?
2432 .into();
2433 let follower_id = session.connection_id;
2434
2435 session
2436 .db()
2437 .await
2438 .check_room_participants(room_id, leader_id, session.connection_id)
2439 .await?;
2440
2441 session
2442 .peer
2443 .forward_send(session.connection_id, leader_id, request)?;
2444
2445 if let Some(project_id) = project_id {
2446 let room = session
2447 .db()
2448 .await
2449 .unfollow(room_id, project_id, leader_id, follower_id)
2450 .await?;
2451 room_updated(&room, &session.peer);
2452 }
2453
2454 Ok(())
2455}
2456
2457/// Notify everyone following you of your current location.
2458async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2459 let room_id = RoomId::from_proto(request.room_id);
2460 let database = session.db.lock().await;
2461
2462 let connection_ids = if let Some(project_id) = request.project_id {
2463 let project_id = ProjectId::from_proto(project_id);
2464 database
2465 .project_connection_ids(project_id, session.connection_id, true)
2466 .await?
2467 } else {
2468 database
2469 .room_connection_ids(room_id, session.connection_id)
2470 .await?
2471 };
2472
2473 // For now, don't send view update messages back to that view's current leader.
2474 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2475 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2476 _ => None,
2477 });
2478
2479 for connection_id in connection_ids.iter().cloned() {
2480 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2481 session
2482 .peer
2483 .forward_send(session.connection_id, connection_id, request.clone())?;
2484 }
2485 }
2486 Ok(())
2487}
2488
2489/// Get public data about users.
2490async fn get_users(
2491 request: proto::GetUsers,
2492 response: Response<proto::GetUsers>,
2493 session: Session,
2494) -> Result<()> {
2495 let user_ids = request
2496 .user_ids
2497 .into_iter()
2498 .map(UserId::from_proto)
2499 .collect();
2500 let users = session
2501 .db()
2502 .await
2503 .get_users_by_ids(user_ids)
2504 .await?
2505 .into_iter()
2506 .map(|user| proto::User {
2507 id: user.id.to_proto(),
2508 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2509 github_login: user.github_login,
2510 email: user.email_address,
2511 name: user.name,
2512 })
2513 .collect();
2514 response.send(proto::UsersResponse { users })?;
2515 Ok(())
2516}
2517
2518/// Search for users (to invite) buy Github login
2519async fn fuzzy_search_users(
2520 request: proto::FuzzySearchUsers,
2521 response: Response<proto::FuzzySearchUsers>,
2522 session: Session,
2523) -> Result<()> {
2524 let query = request.query;
2525 let users = match query.len() {
2526 0 => vec![],
2527 1 | 2 => session
2528 .db()
2529 .await
2530 .get_user_by_github_login(&query)
2531 .await?
2532 .into_iter()
2533 .collect(),
2534 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2535 };
2536 let users = users
2537 .into_iter()
2538 .filter(|user| user.id != session.user_id())
2539 .map(|user| proto::User {
2540 id: user.id.to_proto(),
2541 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2542 github_login: user.github_login,
2543 name: user.name,
2544 email: user.email_address,
2545 })
2546 .collect();
2547 response.send(proto::UsersResponse { users })?;
2548 Ok(())
2549}
2550
2551/// Send a contact request to another user.
2552async fn request_contact(
2553 request: proto::RequestContact,
2554 response: Response<proto::RequestContact>,
2555 session: Session,
2556) -> Result<()> {
2557 let requester_id = session.user_id();
2558 let responder_id = UserId::from_proto(request.responder_id);
2559 if requester_id == responder_id {
2560 return Err(anyhow!("cannot add yourself as a contact"))?;
2561 }
2562
2563 let notifications = session
2564 .db()
2565 .await
2566 .send_contact_request(requester_id, responder_id)
2567 .await?;
2568
2569 // Update outgoing contact requests of requester
2570 let mut update = proto::UpdateContacts::default();
2571 update.outgoing_requests.push(responder_id.to_proto());
2572 for connection_id in session
2573 .connection_pool()
2574 .await
2575 .user_connection_ids(requester_id)
2576 {
2577 session.peer.send(connection_id, update.clone())?;
2578 }
2579
2580 // Update incoming contact requests of responder
2581 let mut update = proto::UpdateContacts::default();
2582 update
2583 .incoming_requests
2584 .push(proto::IncomingContactRequest {
2585 requester_id: requester_id.to_proto(),
2586 });
2587 let connection_pool = session.connection_pool().await;
2588 for connection_id in connection_pool.user_connection_ids(responder_id) {
2589 session.peer.send(connection_id, update.clone())?;
2590 }
2591
2592 send_notifications(&connection_pool, &session.peer, notifications);
2593
2594 response.send(proto::Ack {})?;
2595 Ok(())
2596}
2597
2598/// Accept or decline a contact request
2599async fn respond_to_contact_request(
2600 request: proto::RespondToContactRequest,
2601 response: Response<proto::RespondToContactRequest>,
2602 session: Session,
2603) -> Result<()> {
2604 let responder_id = session.user_id();
2605 let requester_id = UserId::from_proto(request.requester_id);
2606 let db = session.db().await;
2607 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2608 db.dismiss_contact_notification(responder_id, requester_id)
2609 .await?;
2610 } else {
2611 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2612
2613 let notifications = db
2614 .respond_to_contact_request(responder_id, requester_id, accept)
2615 .await?;
2616 let requester_busy = db.is_user_busy(requester_id).await?;
2617 let responder_busy = db.is_user_busy(responder_id).await?;
2618
2619 let pool = session.connection_pool().await;
2620 // Update responder with new contact
2621 let mut update = proto::UpdateContacts::default();
2622 if accept {
2623 update
2624 .contacts
2625 .push(contact_for_user(requester_id, requester_busy, &pool));
2626 }
2627 update
2628 .remove_incoming_requests
2629 .push(requester_id.to_proto());
2630 for connection_id in pool.user_connection_ids(responder_id) {
2631 session.peer.send(connection_id, update.clone())?;
2632 }
2633
2634 // Update requester with new contact
2635 let mut update = proto::UpdateContacts::default();
2636 if accept {
2637 update
2638 .contacts
2639 .push(contact_for_user(responder_id, responder_busy, &pool));
2640 }
2641 update
2642 .remove_outgoing_requests
2643 .push(responder_id.to_proto());
2644
2645 for connection_id in pool.user_connection_ids(requester_id) {
2646 session.peer.send(connection_id, update.clone())?;
2647 }
2648
2649 send_notifications(&pool, &session.peer, notifications);
2650 }
2651
2652 response.send(proto::Ack {})?;
2653 Ok(())
2654}
2655
2656/// Remove a contact.
2657async fn remove_contact(
2658 request: proto::RemoveContact,
2659 response: Response<proto::RemoveContact>,
2660 session: Session,
2661) -> Result<()> {
2662 let requester_id = session.user_id();
2663 let responder_id = UserId::from_proto(request.user_id);
2664 let db = session.db().await;
2665 let (contact_accepted, deleted_notification_id) =
2666 db.remove_contact(requester_id, responder_id).await?;
2667
2668 let pool = session.connection_pool().await;
2669 // Update outgoing contact requests of requester
2670 let mut update = proto::UpdateContacts::default();
2671 if contact_accepted {
2672 update.remove_contacts.push(responder_id.to_proto());
2673 } else {
2674 update
2675 .remove_outgoing_requests
2676 .push(responder_id.to_proto());
2677 }
2678 for connection_id in pool.user_connection_ids(requester_id) {
2679 session.peer.send(connection_id, update.clone())?;
2680 }
2681
2682 // Update incoming contact requests of responder
2683 let mut update = proto::UpdateContacts::default();
2684 if contact_accepted {
2685 update.remove_contacts.push(requester_id.to_proto());
2686 } else {
2687 update
2688 .remove_incoming_requests
2689 .push(requester_id.to_proto());
2690 }
2691 for connection_id in pool.user_connection_ids(responder_id) {
2692 session.peer.send(connection_id, update.clone())?;
2693 if let Some(notification_id) = deleted_notification_id {
2694 session.peer.send(
2695 connection_id,
2696 proto::DeleteNotification {
2697 notification_id: notification_id.to_proto(),
2698 },
2699 )?;
2700 }
2701 }
2702
2703 response.send(proto::Ack {})?;
2704 Ok(())
2705}
2706
2707fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2708 version.0.minor() < 139
2709}
2710
2711async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2712 let plan = session.current_plan(&session.db().await).await?;
2713
2714 session
2715 .peer
2716 .send(
2717 session.connection_id,
2718 proto::UpdateUserPlan { plan: plan.into() },
2719 )
2720 .trace_err();
2721
2722 Ok(())
2723}
2724
2725async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2726 subscribe_user_to_channels(session.user_id(), &session).await?;
2727 Ok(())
2728}
2729
2730async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2731 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2732 let mut pool = session.connection_pool().await;
2733 for membership in &channels_for_user.channel_memberships {
2734 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2735 }
2736 session.peer.send(
2737 session.connection_id,
2738 build_update_user_channels(&channels_for_user),
2739 )?;
2740 session.peer.send(
2741 session.connection_id,
2742 build_channels_update(channels_for_user),
2743 )?;
2744 Ok(())
2745}
2746
2747/// Creates a new channel.
2748async fn create_channel(
2749 request: proto::CreateChannel,
2750 response: Response<proto::CreateChannel>,
2751 session: Session,
2752) -> Result<()> {
2753 let db = session.db().await;
2754
2755 let parent_id = request.parent_id.map(ChannelId::from_proto);
2756 let (channel, membership) = db
2757 .create_channel(&request.name, parent_id, session.user_id())
2758 .await?;
2759
2760 let root_id = channel.root_id();
2761 let channel = Channel::from_model(channel);
2762
2763 response.send(proto::CreateChannelResponse {
2764 channel: Some(channel.to_proto()),
2765 parent_id: request.parent_id,
2766 })?;
2767
2768 let mut connection_pool = session.connection_pool().await;
2769 if let Some(membership) = membership {
2770 connection_pool.subscribe_to_channel(
2771 membership.user_id,
2772 membership.channel_id,
2773 membership.role,
2774 );
2775 let update = proto::UpdateUserChannels {
2776 channel_memberships: vec![proto::ChannelMembership {
2777 channel_id: membership.channel_id.to_proto(),
2778 role: membership.role.into(),
2779 }],
2780 ..Default::default()
2781 };
2782 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2783 session.peer.send(connection_id, update.clone())?;
2784 }
2785 }
2786
2787 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2788 if !role.can_see_channel(channel.visibility) {
2789 continue;
2790 }
2791
2792 let update = proto::UpdateChannels {
2793 channels: vec![channel.to_proto()],
2794 ..Default::default()
2795 };
2796 session.peer.send(connection_id, update.clone())?;
2797 }
2798
2799 Ok(())
2800}
2801
2802/// Delete a channel
2803async fn delete_channel(
2804 request: proto::DeleteChannel,
2805 response: Response<proto::DeleteChannel>,
2806 session: Session,
2807) -> Result<()> {
2808 let db = session.db().await;
2809
2810 let channel_id = request.channel_id;
2811 let (root_channel, removed_channels) = db
2812 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2813 .await?;
2814 response.send(proto::Ack {})?;
2815
2816 // Notify members of removed channels
2817 let mut update = proto::UpdateChannels::default();
2818 update
2819 .delete_channels
2820 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2821
2822 let connection_pool = session.connection_pool().await;
2823 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2824 session.peer.send(connection_id, update.clone())?;
2825 }
2826
2827 Ok(())
2828}
2829
2830/// Invite someone to join a channel.
2831async fn invite_channel_member(
2832 request: proto::InviteChannelMember,
2833 response: Response<proto::InviteChannelMember>,
2834 session: Session,
2835) -> Result<()> {
2836 let db = session.db().await;
2837 let channel_id = ChannelId::from_proto(request.channel_id);
2838 let invitee_id = UserId::from_proto(request.user_id);
2839 let InviteMemberResult {
2840 channel,
2841 notifications,
2842 } = db
2843 .invite_channel_member(
2844 channel_id,
2845 invitee_id,
2846 session.user_id(),
2847 request.role().into(),
2848 )
2849 .await?;
2850
2851 let update = proto::UpdateChannels {
2852 channel_invitations: vec![channel.to_proto()],
2853 ..Default::default()
2854 };
2855
2856 let connection_pool = session.connection_pool().await;
2857 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2858 session.peer.send(connection_id, update.clone())?;
2859 }
2860
2861 send_notifications(&connection_pool, &session.peer, notifications);
2862
2863 response.send(proto::Ack {})?;
2864 Ok(())
2865}
2866
2867/// remove someone from a channel
2868async fn remove_channel_member(
2869 request: proto::RemoveChannelMember,
2870 response: Response<proto::RemoveChannelMember>,
2871 session: Session,
2872) -> Result<()> {
2873 let db = session.db().await;
2874 let channel_id = ChannelId::from_proto(request.channel_id);
2875 let member_id = UserId::from_proto(request.user_id);
2876
2877 let RemoveChannelMemberResult {
2878 membership_update,
2879 notification_id,
2880 } = db
2881 .remove_channel_member(channel_id, member_id, session.user_id())
2882 .await?;
2883
2884 let mut connection_pool = session.connection_pool().await;
2885 notify_membership_updated(
2886 &mut connection_pool,
2887 membership_update,
2888 member_id,
2889 &session.peer,
2890 );
2891 for connection_id in connection_pool.user_connection_ids(member_id) {
2892 if let Some(notification_id) = notification_id {
2893 session
2894 .peer
2895 .send(
2896 connection_id,
2897 proto::DeleteNotification {
2898 notification_id: notification_id.to_proto(),
2899 },
2900 )
2901 .trace_err();
2902 }
2903 }
2904
2905 response.send(proto::Ack {})?;
2906 Ok(())
2907}
2908
2909/// Toggle the channel between public and private.
2910/// Care is taken to maintain the invariant that public channels only descend from public channels,
2911/// (though members-only channels can appear at any point in the hierarchy).
2912async fn set_channel_visibility(
2913 request: proto::SetChannelVisibility,
2914 response: Response<proto::SetChannelVisibility>,
2915 session: Session,
2916) -> Result<()> {
2917 let db = session.db().await;
2918 let channel_id = ChannelId::from_proto(request.channel_id);
2919 let visibility = request.visibility().into();
2920
2921 let channel_model = db
2922 .set_channel_visibility(channel_id, visibility, session.user_id())
2923 .await?;
2924 let root_id = channel_model.root_id();
2925 let channel = Channel::from_model(channel_model);
2926
2927 let mut connection_pool = session.connection_pool().await;
2928 for (user_id, role) in connection_pool
2929 .channel_user_ids(root_id)
2930 .collect::<Vec<_>>()
2931 .into_iter()
2932 {
2933 let update = if role.can_see_channel(channel.visibility) {
2934 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2935 proto::UpdateChannels {
2936 channels: vec![channel.to_proto()],
2937 ..Default::default()
2938 }
2939 } else {
2940 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2941 proto::UpdateChannels {
2942 delete_channels: vec![channel.id.to_proto()],
2943 ..Default::default()
2944 }
2945 };
2946
2947 for connection_id in connection_pool.user_connection_ids(user_id) {
2948 session.peer.send(connection_id, update.clone())?;
2949 }
2950 }
2951
2952 response.send(proto::Ack {})?;
2953 Ok(())
2954}
2955
2956/// Alter the role for a user in the channel.
2957async fn set_channel_member_role(
2958 request: proto::SetChannelMemberRole,
2959 response: Response<proto::SetChannelMemberRole>,
2960 session: Session,
2961) -> Result<()> {
2962 let db = session.db().await;
2963 let channel_id = ChannelId::from_proto(request.channel_id);
2964 let member_id = UserId::from_proto(request.user_id);
2965 let result = db
2966 .set_channel_member_role(
2967 channel_id,
2968 session.user_id(),
2969 member_id,
2970 request.role().into(),
2971 )
2972 .await?;
2973
2974 match result {
2975 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2976 let mut connection_pool = session.connection_pool().await;
2977 notify_membership_updated(
2978 &mut connection_pool,
2979 membership_update,
2980 member_id,
2981 &session.peer,
2982 )
2983 }
2984 db::SetMemberRoleResult::InviteUpdated(channel) => {
2985 let update = proto::UpdateChannels {
2986 channel_invitations: vec![channel.to_proto()],
2987 ..Default::default()
2988 };
2989
2990 for connection_id in session
2991 .connection_pool()
2992 .await
2993 .user_connection_ids(member_id)
2994 {
2995 session.peer.send(connection_id, update.clone())?;
2996 }
2997 }
2998 }
2999
3000 response.send(proto::Ack {})?;
3001 Ok(())
3002}
3003
3004/// Change the name of a channel
3005async fn rename_channel(
3006 request: proto::RenameChannel,
3007 response: Response<proto::RenameChannel>,
3008 session: Session,
3009) -> Result<()> {
3010 let db = session.db().await;
3011 let channel_id = ChannelId::from_proto(request.channel_id);
3012 let channel_model = db
3013 .rename_channel(channel_id, session.user_id(), &request.name)
3014 .await?;
3015 let root_id = channel_model.root_id();
3016 let channel = Channel::from_model(channel_model);
3017
3018 response.send(proto::RenameChannelResponse {
3019 channel: Some(channel.to_proto()),
3020 })?;
3021
3022 let connection_pool = session.connection_pool().await;
3023 let update = proto::UpdateChannels {
3024 channels: vec![channel.to_proto()],
3025 ..Default::default()
3026 };
3027 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3028 if role.can_see_channel(channel.visibility) {
3029 session.peer.send(connection_id, update.clone())?;
3030 }
3031 }
3032
3033 Ok(())
3034}
3035
3036/// Move a channel to a new parent.
3037async fn move_channel(
3038 request: proto::MoveChannel,
3039 response: Response<proto::MoveChannel>,
3040 session: Session,
3041) -> Result<()> {
3042 let channel_id = ChannelId::from_proto(request.channel_id);
3043 let to = ChannelId::from_proto(request.to);
3044
3045 let (root_id, channels) = session
3046 .db()
3047 .await
3048 .move_channel(channel_id, to, session.user_id())
3049 .await?;
3050
3051 let connection_pool = session.connection_pool().await;
3052 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3053 let channels = channels
3054 .iter()
3055 .filter_map(|channel| {
3056 if role.can_see_channel(channel.visibility) {
3057 Some(channel.to_proto())
3058 } else {
3059 None
3060 }
3061 })
3062 .collect::<Vec<_>>();
3063 if channels.is_empty() {
3064 continue;
3065 }
3066
3067 let update = proto::UpdateChannels {
3068 channels,
3069 ..Default::default()
3070 };
3071
3072 session.peer.send(connection_id, update.clone())?;
3073 }
3074
3075 response.send(Ack {})?;
3076 Ok(())
3077}
3078
3079/// Get the list of channel members
3080async fn get_channel_members(
3081 request: proto::GetChannelMembers,
3082 response: Response<proto::GetChannelMembers>,
3083 session: Session,
3084) -> Result<()> {
3085 let db = session.db().await;
3086 let channel_id = ChannelId::from_proto(request.channel_id);
3087 let limit = if request.limit == 0 {
3088 u16::MAX as u64
3089 } else {
3090 request.limit
3091 };
3092 let (members, users) = db
3093 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3094 .await?;
3095 response.send(proto::GetChannelMembersResponse { members, users })?;
3096 Ok(())
3097}
3098
3099/// Accept or decline a channel invitation.
3100async fn respond_to_channel_invite(
3101 request: proto::RespondToChannelInvite,
3102 response: Response<proto::RespondToChannelInvite>,
3103 session: Session,
3104) -> Result<()> {
3105 let db = session.db().await;
3106 let channel_id = ChannelId::from_proto(request.channel_id);
3107 let RespondToChannelInvite {
3108 membership_update,
3109 notifications,
3110 } = db
3111 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3112 .await?;
3113
3114 let mut connection_pool = session.connection_pool().await;
3115 if let Some(membership_update) = membership_update {
3116 notify_membership_updated(
3117 &mut connection_pool,
3118 membership_update,
3119 session.user_id(),
3120 &session.peer,
3121 );
3122 } else {
3123 let update = proto::UpdateChannels {
3124 remove_channel_invitations: vec![channel_id.to_proto()],
3125 ..Default::default()
3126 };
3127
3128 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3129 session.peer.send(connection_id, update.clone())?;
3130 }
3131 };
3132
3133 send_notifications(&connection_pool, &session.peer, notifications);
3134
3135 response.send(proto::Ack {})?;
3136
3137 Ok(())
3138}
3139
3140/// Join the channels' room
3141async fn join_channel(
3142 request: proto::JoinChannel,
3143 response: Response<proto::JoinChannel>,
3144 session: Session,
3145) -> Result<()> {
3146 let channel_id = ChannelId::from_proto(request.channel_id);
3147 join_channel_internal(channel_id, Box::new(response), session).await
3148}
3149
/// Abstraction over the response handles that can answer with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for the responses of both `proto::JoinChannel` and
/// `proto::JoinRoom`, so one code path can reply to either request type.
trait JoinChannelInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the inherent `Response::send` is called, not this
        // trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the inherent `Response::send` is called, not this
        // trait method of the same name.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3163
3164async fn join_channel_internal(
3165 channel_id: ChannelId,
3166 response: Box<impl JoinChannelInternalResponse>,
3167 session: Session,
3168) -> Result<()> {
3169 let joined_room = {
3170 let mut db = session.db().await;
3171 // If zed quits without leaving the room, and the user re-opens zed before the
3172 // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3173 // room they were in.
3174 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3175 tracing::info!(
3176 stale_connection_id = %connection,
3177 "cleaning up stale connection",
3178 );
3179 drop(db);
3180 leave_room_for_session(&session, connection).await?;
3181 db = session.db().await;
3182 }
3183
3184 let (joined_room, membership_updated, role) = db
3185 .join_channel(channel_id, session.user_id(), session.connection_id)
3186 .await?;
3187
3188 let live_kit_connection_info =
3189 session
3190 .app_state
3191 .livekit_client
3192 .as_ref()
3193 .and_then(|live_kit| {
3194 let (can_publish, token) = if role == ChannelRole::Guest {
3195 (
3196 false,
3197 live_kit
3198 .guest_token(
3199 &joined_room.room.livekit_room,
3200 &session.user_id().to_string(),
3201 )
3202 .trace_err()?,
3203 )
3204 } else {
3205 (
3206 true,
3207 live_kit
3208 .room_token(
3209 &joined_room.room.livekit_room,
3210 &session.user_id().to_string(),
3211 )
3212 .trace_err()?,
3213 )
3214 };
3215
3216 Some(LiveKitConnectionInfo {
3217 server_url: live_kit.url().into(),
3218 token,
3219 can_publish,
3220 })
3221 });
3222
3223 response.send(proto::JoinRoomResponse {
3224 room: Some(joined_room.room.clone()),
3225 channel_id: joined_room
3226 .channel
3227 .as_ref()
3228 .map(|channel| channel.id.to_proto()),
3229 live_kit_connection_info,
3230 })?;
3231
3232 let mut connection_pool = session.connection_pool().await;
3233 if let Some(membership_updated) = membership_updated {
3234 notify_membership_updated(
3235 &mut connection_pool,
3236 membership_updated,
3237 session.user_id(),
3238 &session.peer,
3239 );
3240 }
3241
3242 room_updated(&joined_room.room, &session.peer);
3243
3244 joined_room
3245 };
3246
3247 channel_updated(
3248 &joined_room
3249 .channel
3250 .ok_or_else(|| anyhow!("channel not returned"))?,
3251 &joined_room.room,
3252 &session.peer,
3253 &*session.connection_pool().await,
3254 );
3255
3256 update_user_contacts(session.user_id(), &session).await?;
3257 Ok(())
3258}
3259
3260/// Start editing the channel notes
3261async fn join_channel_buffer(
3262 request: proto::JoinChannelBuffer,
3263 response: Response<proto::JoinChannelBuffer>,
3264 session: Session,
3265) -> Result<()> {
3266 let db = session.db().await;
3267 let channel_id = ChannelId::from_proto(request.channel_id);
3268
3269 let open_response = db
3270 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3271 .await?;
3272
3273 let collaborators = open_response.collaborators.clone();
3274 response.send(open_response)?;
3275
3276 let update = UpdateChannelBufferCollaborators {
3277 channel_id: channel_id.to_proto(),
3278 collaborators: collaborators.clone(),
3279 };
3280 channel_buffer_updated(
3281 session.connection_id,
3282 collaborators
3283 .iter()
3284 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3285 &update,
3286 &session.peer,
3287 );
3288
3289 Ok(())
3290}
3291
3292/// Edit the channel notes
3293async fn update_channel_buffer(
3294 request: proto::UpdateChannelBuffer,
3295 session: Session,
3296) -> Result<()> {
3297 let db = session.db().await;
3298 let channel_id = ChannelId::from_proto(request.channel_id);
3299
3300 let (collaborators, epoch, version) = db
3301 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3302 .await?;
3303
3304 channel_buffer_updated(
3305 session.connection_id,
3306 collaborators.clone(),
3307 &proto::UpdateChannelBuffer {
3308 channel_id: channel_id.to_proto(),
3309 operations: request.operations,
3310 },
3311 &session.peer,
3312 );
3313
3314 let pool = &*session.connection_pool().await;
3315
3316 let non_collaborators =
3317 pool.channel_connection_ids(channel_id)
3318 .filter_map(|(connection_id, _)| {
3319 if collaborators.contains(&connection_id) {
3320 None
3321 } else {
3322 Some(connection_id)
3323 }
3324 });
3325
3326 broadcast(None, non_collaborators, |peer_id| {
3327 session.peer.send(
3328 peer_id,
3329 proto::UpdateChannels {
3330 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3331 channel_id: channel_id.to_proto(),
3332 epoch: epoch as u64,
3333 version: version.clone(),
3334 }],
3335 ..Default::default()
3336 },
3337 )
3338 });
3339
3340 Ok(())
3341}
3342
3343/// Rejoin the channel notes after a connection blip
3344async fn rejoin_channel_buffers(
3345 request: proto::RejoinChannelBuffers,
3346 response: Response<proto::RejoinChannelBuffers>,
3347 session: Session,
3348) -> Result<()> {
3349 let db = session.db().await;
3350 let buffers = db
3351 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3352 .await?;
3353
3354 for rejoined_buffer in &buffers {
3355 let collaborators_to_notify = rejoined_buffer
3356 .buffer
3357 .collaborators
3358 .iter()
3359 .filter_map(|c| Some(c.peer_id?.into()));
3360 channel_buffer_updated(
3361 session.connection_id,
3362 collaborators_to_notify,
3363 &proto::UpdateChannelBufferCollaborators {
3364 channel_id: rejoined_buffer.buffer.channel_id,
3365 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3366 },
3367 &session.peer,
3368 );
3369 }
3370
3371 response.send(proto::RejoinChannelBuffersResponse {
3372 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3373 })?;
3374
3375 Ok(())
3376}
3377
3378/// Stop editing the channel notes
3379async fn leave_channel_buffer(
3380 request: proto::LeaveChannelBuffer,
3381 response: Response<proto::LeaveChannelBuffer>,
3382 session: Session,
3383) -> Result<()> {
3384 let db = session.db().await;
3385 let channel_id = ChannelId::from_proto(request.channel_id);
3386
3387 let left_buffer = db
3388 .leave_channel_buffer(channel_id, session.connection_id)
3389 .await?;
3390
3391 response.send(Ack {})?;
3392
3393 channel_buffer_updated(
3394 session.connection_id,
3395 left_buffer.connections,
3396 &proto::UpdateChannelBufferCollaborators {
3397 channel_id: channel_id.to_proto(),
3398 collaborators: left_buffer.collaborators,
3399 },
3400 &session.peer,
3401 );
3402
3403 Ok(())
3404}
3405
3406fn channel_buffer_updated<T: EnvelopedMessage>(
3407 sender_id: ConnectionId,
3408 collaborators: impl IntoIterator<Item = ConnectionId>,
3409 message: &T,
3410 peer: &Peer,
3411) {
3412 broadcast(Some(sender_id), collaborators, |peer_id| {
3413 peer.send(peer_id, message.clone())
3414 });
3415}
3416
3417fn send_notifications(
3418 connection_pool: &ConnectionPool,
3419 peer: &Peer,
3420 notifications: db::NotificationBatch,
3421) {
3422 for (user_id, notification) in notifications {
3423 for connection_id in connection_pool.user_connection_ids(user_id) {
3424 if let Err(error) = peer.send(
3425 connection_id,
3426 proto::AddNotification {
3427 notification: Some(notification.clone()),
3428 },
3429 ) {
3430 tracing::error!(
3431 "failed to send notification to {:?} {}",
3432 connection_id,
3433 error
3434 );
3435 }
3436 }
3437 }
3438}
3439
3440/// Send a message to the channel
3441async fn send_channel_message(
3442 request: proto::SendChannelMessage,
3443 response: Response<proto::SendChannelMessage>,
3444 session: Session,
3445) -> Result<()> {
3446 // Validate the message body.
3447 let body = request.body.trim().to_string();
3448 if body.len() > MAX_MESSAGE_LEN {
3449 return Err(anyhow!("message is too long"))?;
3450 }
3451 if body.is_empty() {
3452 return Err(anyhow!("message can't be blank"))?;
3453 }
3454
3455 // TODO: adjust mentions if body is trimmed
3456
3457 let timestamp = OffsetDateTime::now_utc();
3458 let nonce = request
3459 .nonce
3460 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3461
3462 let channel_id = ChannelId::from_proto(request.channel_id);
3463 let CreatedChannelMessage {
3464 message_id,
3465 participant_connection_ids,
3466 notifications,
3467 } = session
3468 .db()
3469 .await
3470 .create_channel_message(
3471 channel_id,
3472 session.user_id(),
3473 &body,
3474 &request.mentions,
3475 timestamp,
3476 nonce.clone().into(),
3477 request.reply_to_message_id.map(MessageId::from_proto),
3478 )
3479 .await?;
3480
3481 let message = proto::ChannelMessage {
3482 sender_id: session.user_id().to_proto(),
3483 id: message_id.to_proto(),
3484 body,
3485 mentions: request.mentions,
3486 timestamp: timestamp.unix_timestamp() as u64,
3487 nonce: Some(nonce),
3488 reply_to_message_id: request.reply_to_message_id,
3489 edited_at: None,
3490 };
3491 broadcast(
3492 Some(session.connection_id),
3493 participant_connection_ids.clone(),
3494 |connection| {
3495 session.peer.send(
3496 connection,
3497 proto::ChannelMessageSent {
3498 channel_id: channel_id.to_proto(),
3499 message: Some(message.clone()),
3500 },
3501 )
3502 },
3503 );
3504 response.send(proto::SendChannelMessageResponse {
3505 message: Some(message),
3506 })?;
3507
3508 let pool = &*session.connection_pool().await;
3509 let non_participants =
3510 pool.channel_connection_ids(channel_id)
3511 .filter_map(|(connection_id, _)| {
3512 if participant_connection_ids.contains(&connection_id) {
3513 None
3514 } else {
3515 Some(connection_id)
3516 }
3517 });
3518 broadcast(None, non_participants, |peer_id| {
3519 session.peer.send(
3520 peer_id,
3521 proto::UpdateChannels {
3522 latest_channel_message_ids: vec![proto::ChannelMessageId {
3523 channel_id: channel_id.to_proto(),
3524 message_id: message_id.to_proto(),
3525 }],
3526 ..Default::default()
3527 },
3528 )
3529 });
3530 send_notifications(pool, &session.peer, notifications);
3531
3532 Ok(())
3533}
3534
3535/// Delete a channel message
3536async fn remove_channel_message(
3537 request: proto::RemoveChannelMessage,
3538 response: Response<proto::RemoveChannelMessage>,
3539 session: Session,
3540) -> Result<()> {
3541 let channel_id = ChannelId::from_proto(request.channel_id);
3542 let message_id = MessageId::from_proto(request.message_id);
3543 let (connection_ids, existing_notification_ids) = session
3544 .db()
3545 .await
3546 .remove_channel_message(channel_id, message_id, session.user_id())
3547 .await?;
3548
3549 broadcast(
3550 Some(session.connection_id),
3551 connection_ids,
3552 move |connection| {
3553 session.peer.send(connection, request.clone())?;
3554
3555 for notification_id in &existing_notification_ids {
3556 session.peer.send(
3557 connection,
3558 proto::DeleteNotification {
3559 notification_id: (*notification_id).to_proto(),
3560 },
3561 )?;
3562 }
3563
3564 Ok(())
3565 },
3566 );
3567 response.send(proto::Ack {})?;
3568 Ok(())
3569}
3570
3571async fn update_channel_message(
3572 request: proto::UpdateChannelMessage,
3573 response: Response<proto::UpdateChannelMessage>,
3574 session: Session,
3575) -> Result<()> {
3576 let channel_id = ChannelId::from_proto(request.channel_id);
3577 let message_id = MessageId::from_proto(request.message_id);
3578 let updated_at = OffsetDateTime::now_utc();
3579 let UpdatedChannelMessage {
3580 message_id,
3581 participant_connection_ids,
3582 notifications,
3583 reply_to_message_id,
3584 timestamp,
3585 deleted_mention_notification_ids,
3586 updated_mention_notifications,
3587 } = session
3588 .db()
3589 .await
3590 .update_channel_message(
3591 channel_id,
3592 message_id,
3593 session.user_id(),
3594 request.body.as_str(),
3595 &request.mentions,
3596 updated_at,
3597 )
3598 .await?;
3599
3600 let nonce = request
3601 .nonce
3602 .clone()
3603 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3604
3605 let message = proto::ChannelMessage {
3606 sender_id: session.user_id().to_proto(),
3607 id: message_id.to_proto(),
3608 body: request.body.clone(),
3609 mentions: request.mentions.clone(),
3610 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3611 nonce: Some(nonce),
3612 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3613 edited_at: Some(updated_at.unix_timestamp() as u64),
3614 };
3615
3616 response.send(proto::Ack {})?;
3617
3618 let pool = &*session.connection_pool().await;
3619 broadcast(
3620 Some(session.connection_id),
3621 participant_connection_ids,
3622 |connection| {
3623 session.peer.send(
3624 connection,
3625 proto::ChannelMessageUpdate {
3626 channel_id: channel_id.to_proto(),
3627 message: Some(message.clone()),
3628 },
3629 )?;
3630
3631 for notification_id in &deleted_mention_notification_ids {
3632 session.peer.send(
3633 connection,
3634 proto::DeleteNotification {
3635 notification_id: (*notification_id).to_proto(),
3636 },
3637 )?;
3638 }
3639
3640 for notification in &updated_mention_notifications {
3641 session.peer.send(
3642 connection,
3643 proto::UpdateNotification {
3644 notification: Some(notification.clone()),
3645 },
3646 )?;
3647 }
3648
3649 Ok(())
3650 },
3651 );
3652
3653 send_notifications(pool, &session.peer, notifications);
3654
3655 Ok(())
3656}
3657
3658/// Mark a channel message as read
3659async fn acknowledge_channel_message(
3660 request: proto::AckChannelMessage,
3661 session: Session,
3662) -> Result<()> {
3663 let channel_id = ChannelId::from_proto(request.channel_id);
3664 let message_id = MessageId::from_proto(request.message_id);
3665 let notifications = session
3666 .db()
3667 .await
3668 .observe_channel_message(channel_id, session.user_id(), message_id)
3669 .await?;
3670 send_notifications(
3671 &*session.connection_pool().await,
3672 &session.peer,
3673 notifications,
3674 );
3675 Ok(())
3676}
3677
3678/// Mark a buffer version as synced
3679async fn acknowledge_buffer_version(
3680 request: proto::AckBufferOperation,
3681 session: Session,
3682) -> Result<()> {
3683 let buffer_id = BufferId::from_proto(request.buffer_id);
3684 session
3685 .db()
3686 .await
3687 .observe_buffer_version(
3688 buffer_id,
3689 session.user_id(),
3690 request.epoch as i32,
3691 &request.version,
3692 )
3693 .await?;
3694 Ok(())
3695}
3696
3697async fn count_language_model_tokens(
3698 request: proto::CountLanguageModelTokens,
3699 response: Response<proto::CountLanguageModelTokens>,
3700 session: Session,
3701 config: &Config,
3702) -> Result<()> {
3703 authorize_access_to_legacy_llm_endpoints(&session).await?;
3704
3705 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3706 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3707 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3708 };
3709
3710 session
3711 .app_state
3712 .rate_limiter
3713 .check(&*rate_limit, session.user_id())
3714 .await?;
3715
3716 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3717 Some(proto::LanguageModelProvider::Google) => {
3718 let api_key = config
3719 .google_ai_api_key
3720 .as_ref()
3721 .context("no Google AI API key configured on the server")?;
3722 google_ai::count_tokens(
3723 session.http_client.as_ref(),
3724 google_ai::API_URL,
3725 api_key,
3726 serde_json::from_str(&request.request)?,
3727 )
3728 .await?
3729 }
3730 _ => return Err(anyhow!("unsupported provider"))?,
3731 };
3732
3733 response.send(proto::CountLanguageModelTokensResponse {
3734 token_count: result.total_tokens as u32,
3735 })?;
3736
3737 Ok(())
3738}
3739
3740struct ZedProCountLanguageModelTokensRateLimit;
3741
3742impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3743 fn capacity(&self) -> usize {
3744 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3745 .ok()
3746 .and_then(|v| v.parse().ok())
3747 .unwrap_or(600) // Picked arbitrarily
3748 }
3749
3750 fn refill_duration(&self) -> chrono::Duration {
3751 chrono::Duration::hours(1)
3752 }
3753
3754 fn db_name(&self) -> &'static str {
3755 "zed-pro:count-language-model-tokens"
3756 }
3757}
3758
3759struct FreeCountLanguageModelTokensRateLimit;
3760
3761impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3762 fn capacity(&self) -> usize {
3763 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3764 .ok()
3765 .and_then(|v| v.parse().ok())
3766 .unwrap_or(600 / 10) // Picked arbitrarily
3767 }
3768
3769 fn refill_duration(&self) -> chrono::Duration {
3770 chrono::Duration::hours(1)
3771 }
3772
3773 fn db_name(&self) -> &'static str {
3774 "free:count-language-model-tokens"
3775 }
3776}
3777
3778struct ZedProComputeEmbeddingsRateLimit;
3779
3780impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3781 fn capacity(&self) -> usize {
3782 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3783 .ok()
3784 .and_then(|v| v.parse().ok())
3785 .unwrap_or(5000) // Picked arbitrarily
3786 }
3787
3788 fn refill_duration(&self) -> chrono::Duration {
3789 chrono::Duration::hours(1)
3790 }
3791
3792 fn db_name(&self) -> &'static str {
3793 "zed-pro:compute-embeddings"
3794 }
3795}
3796
3797struct FreeComputeEmbeddingsRateLimit;
3798
3799impl RateLimit for FreeComputeEmbeddingsRateLimit {
3800 fn capacity(&self) -> usize {
3801 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3802 .ok()
3803 .and_then(|v| v.parse().ok())
3804 .unwrap_or(5000 / 10) // Picked arbitrarily
3805 }
3806
3807 fn refill_duration(&self) -> chrono::Duration {
3808 chrono::Duration::hours(1)
3809 }
3810
3811 fn db_name(&self) -> &'static str {
3812 "free:compute-embeddings"
3813 }
3814}
3815
3816async fn compute_embeddings(
3817 request: proto::ComputeEmbeddings,
3818 response: Response<proto::ComputeEmbeddings>,
3819 session: Session,
3820 api_key: Option<Arc<str>>,
3821) -> Result<()> {
3822 let api_key = api_key.context("no OpenAI API key configured on the server")?;
3823 authorize_access_to_legacy_llm_endpoints(&session).await?;
3824
3825 let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3826 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3827 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3828 };
3829
3830 session
3831 .app_state
3832 .rate_limiter
3833 .check(&*rate_limit, session.user_id())
3834 .await?;
3835
3836 let embeddings = match request.model.as_str() {
3837 "openai/text-embedding-3-small" => {
3838 open_ai::embed(
3839 session.http_client.as_ref(),
3840 OPEN_AI_API_URL,
3841 &api_key,
3842 OpenAiEmbeddingModel::TextEmbedding3Small,
3843 request.texts.iter().map(|text| text.as_str()),
3844 )
3845 .await?
3846 }
3847 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3848 };
3849
3850 let embeddings = request
3851 .texts
3852 .iter()
3853 .map(|text| {
3854 let mut hasher = sha2::Sha256::new();
3855 hasher.update(text.as_bytes());
3856 let result = hasher.finalize();
3857 result.to_vec()
3858 })
3859 .zip(
3860 embeddings
3861 .data
3862 .into_iter()
3863 .map(|embedding| embedding.embedding),
3864 )
3865 .collect::<HashMap<_, _>>();
3866
3867 let db = session.db().await;
3868 db.save_embeddings(&request.model, &embeddings)
3869 .await
3870 .context("failed to save embeddings")
3871 .trace_err();
3872
3873 response.send(proto::ComputeEmbeddingsResponse {
3874 embeddings: embeddings
3875 .into_iter()
3876 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3877 .collect(),
3878 })?;
3879 Ok(())
3880}
3881
3882async fn get_cached_embeddings(
3883 request: proto::GetCachedEmbeddings,
3884 response: Response<proto::GetCachedEmbeddings>,
3885 session: Session,
3886) -> Result<()> {
3887 authorize_access_to_legacy_llm_endpoints(&session).await?;
3888
3889 let db = session.db().await;
3890 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3891
3892 response.send(proto::GetCachedEmbeddingsResponse {
3893 embeddings: embeddings
3894 .into_iter()
3895 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3896 .collect(),
3897 })?;
3898 Ok(())
3899}
3900
3901/// This is leftover from before the LLM service.
3902///
3903/// The endpoints protected by this check will be moved there eventually.
3904async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3905 if session.is_staff() {
3906 Ok(())
3907 } else {
3908 Err(anyhow!("permission denied"))?
3909 }
3910}
3911
3912/// Get a Supermaven API key for the user
3913async fn get_supermaven_api_key(
3914 _request: proto::GetSupermavenApiKey,
3915 response: Response<proto::GetSupermavenApiKey>,
3916 session: Session,
3917) -> Result<()> {
3918 let user_id: String = session.user_id().to_string();
3919 if !session.is_staff() {
3920 return Err(anyhow!("supermaven not enabled for this account"))?;
3921 }
3922
3923 let email = session
3924 .email()
3925 .ok_or_else(|| anyhow!("user must have an email"))?;
3926
3927 let supermaven_admin_api = session
3928 .supermaven_client
3929 .as_ref()
3930 .ok_or_else(|| anyhow!("supermaven not configured"))?;
3931
3932 let result = supermaven_admin_api
3933 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3934 .await?;
3935
3936 response.send(proto::GetSupermavenApiKeyResponse {
3937 api_key: result.api_key,
3938 })?;
3939
3940 Ok(())
3941}
3942
3943/// Start receiving chat updates for a channel
3944async fn join_channel_chat(
3945 request: proto::JoinChannelChat,
3946 response: Response<proto::JoinChannelChat>,
3947 session: Session,
3948) -> Result<()> {
3949 let channel_id = ChannelId::from_proto(request.channel_id);
3950
3951 let db = session.db().await;
3952 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3953 .await?;
3954 let messages = db
3955 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3956 .await?;
3957 response.send(proto::JoinChannelChatResponse {
3958 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3959 messages,
3960 })?;
3961 Ok(())
3962}
3963
3964/// Stop receiving chat updates for a channel
3965async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3966 let channel_id = ChannelId::from_proto(request.channel_id);
3967 session
3968 .db()
3969 .await
3970 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3971 .await?;
3972 Ok(())
3973}
3974
3975/// Retrieve the chat history for a channel
3976async fn get_channel_messages(
3977 request: proto::GetChannelMessages,
3978 response: Response<proto::GetChannelMessages>,
3979 session: Session,
3980) -> Result<()> {
3981 let channel_id = ChannelId::from_proto(request.channel_id);
3982 let messages = session
3983 .db()
3984 .await
3985 .get_channel_messages(
3986 channel_id,
3987 session.user_id(),
3988 MESSAGE_COUNT_PER_PAGE,
3989 Some(MessageId::from_proto(request.before_message_id)),
3990 )
3991 .await?;
3992 response.send(proto::GetChannelMessagesResponse {
3993 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3994 messages,
3995 })?;
3996 Ok(())
3997}
3998
3999/// Retrieve specific chat messages
4000async fn get_channel_messages_by_id(
4001 request: proto::GetChannelMessagesById,
4002 response: Response<proto::GetChannelMessagesById>,
4003 session: Session,
4004) -> Result<()> {
4005 let message_ids = request
4006 .message_ids
4007 .iter()
4008 .map(|id| MessageId::from_proto(*id))
4009 .collect::<Vec<_>>();
4010 let messages = session
4011 .db()
4012 .await
4013 .get_channel_messages_by_id(session.user_id(), &message_ids)
4014 .await?;
4015 response.send(proto::GetChannelMessagesResponse {
4016 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4017 messages,
4018 })?;
4019 Ok(())
4020}
4021
4022/// Retrieve the current users notifications
4023async fn get_notifications(
4024 request: proto::GetNotifications,
4025 response: Response<proto::GetNotifications>,
4026 session: Session,
4027) -> Result<()> {
4028 let notifications = session
4029 .db()
4030 .await
4031 .get_notifications(
4032 session.user_id(),
4033 NOTIFICATION_COUNT_PER_PAGE,
4034 request.before_id.map(db::NotificationId::from_proto),
4035 )
4036 .await?;
4037 response.send(proto::GetNotificationsResponse {
4038 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4039 notifications,
4040 })?;
4041 Ok(())
4042}
4043
4044/// Mark notifications as read
4045async fn mark_notification_as_read(
4046 request: proto::MarkNotificationRead,
4047 response: Response<proto::MarkNotificationRead>,
4048 session: Session,
4049) -> Result<()> {
4050 let database = &session.db().await;
4051 let notifications = database
4052 .mark_notification_as_read_by_id(
4053 session.user_id(),
4054 NotificationId::from_proto(request.notification_id),
4055 )
4056 .await?;
4057 send_notifications(
4058 &*session.connection_pool().await,
4059 &session.peer,
4060 notifications,
4061 );
4062 response.send(proto::Ack {})?;
4063 Ok(())
4064}
4065
4066/// Get the current users information
4067async fn get_private_user_info(
4068 _request: proto::GetPrivateUserInfo,
4069 response: Response<proto::GetPrivateUserInfo>,
4070 session: Session,
4071) -> Result<()> {
4072 let db = session.db().await;
4073
4074 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4075 let user = db
4076 .get_user_by_id(session.user_id())
4077 .await?
4078 .ok_or_else(|| anyhow!("user not found"))?;
4079 let flags = db.get_user_flags(session.user_id()).await?;
4080
4081 response.send(proto::GetPrivateUserInfoResponse {
4082 metrics_id,
4083 staff: user.admin,
4084 flags,
4085 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4086 })?;
4087 Ok(())
4088}
4089
4090/// Accept the terms of service (tos) on behalf of the current user
4091async fn accept_terms_of_service(
4092 _request: proto::AcceptTermsOfService,
4093 response: Response<proto::AcceptTermsOfService>,
4094 session: Session,
4095) -> Result<()> {
4096 let db = session.db().await;
4097
4098 let accepted_tos_at = Utc::now();
4099 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4100 .await?;
4101
4102 response.send(proto::AcceptTermsOfServiceResponse {
4103 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4104 })?;
4105 Ok(())
4106}
4107
/// The minimum account age an account must have in order to use the LLM service.
///
/// Currently set to 30 days.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4110
4111async fn get_llm_api_token(
4112 _request: proto::GetLlmToken,
4113 response: Response<proto::GetLlmToken>,
4114 session: Session,
4115) -> Result<()> {
4116 let db = session.db().await;
4117
4118 let flags = db.get_user_flags(session.user_id()).await?;
4119 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4120
4121 if !session.is_staff() && !has_language_models_feature_flag {
4122 Err(anyhow!("permission denied"))?
4123 }
4124
4125 let user_id = session.user_id();
4126 let user = db
4127 .get_user_by_id(user_id)
4128 .await?
4129 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4130
4131 if user.accepted_tos_at.is_none() {
4132 Err(anyhow!("terms of service not accepted"))?
4133 }
4134
4135 let has_llm_subscription = session.has_llm_subscription(&db).await?;
4136 let billing_preferences = db.get_billing_preferences(user.id).await?;
4137
4138 let token = LlmTokenClaims::create(
4139 &user,
4140 session.is_staff(),
4141 billing_preferences,
4142 &flags,
4143 has_llm_subscription,
4144 session.current_plan(&db).await?,
4145 session.system_id.clone(),
4146 &session.app_state.config,
4147 )?;
4148 response.send(proto::GetLlmTokenResponse { token })?;
4149 Ok(())
4150}
4151
4152fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4153 let message = match message {
4154 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4155 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4156 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4157 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4158 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4159 code: frame.code.into(),
4160 reason: frame.reason,
4161 })),
4162 // We should never receive a frame while reading the message, according
4163 // to the `tungstenite` maintainers:
4164 //
4165 // > It cannot occur when you read messages from the WebSocket, but it
4166 // > can be used when you want to send the raw frames (e.g. you want to
4167 // > send the frames to the WebSocket without composing the full message first).
4168 // >
4169 // > — https://github.com/snapview/tungstenite-rs/issues/268
4170 TungsteniteMessage::Frame(_) => {
4171 bail!("received an unexpected frame while reading the message")
4172 }
4173 };
4174
4175 Ok(message)
4176}
4177
4178fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4179 match message {
4180 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4181 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4182 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4183 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4184 AxumMessage::Close(frame) => {
4185 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4186 code: frame.code.into(),
4187 reason: frame.reason,
4188 }))
4189 }
4190 }
4191}
4192
4193fn notify_membership_updated(
4194 connection_pool: &mut ConnectionPool,
4195 result: MembershipUpdated,
4196 user_id: UserId,
4197 peer: &Peer,
4198) {
4199 for membership in &result.new_channels.channel_memberships {
4200 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4201 }
4202 for channel_id in &result.removed_channels {
4203 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4204 }
4205
4206 let user_channels_update = proto::UpdateUserChannels {
4207 channel_memberships: result
4208 .new_channels
4209 .channel_memberships
4210 .iter()
4211 .map(|cm| proto::ChannelMembership {
4212 channel_id: cm.channel_id.to_proto(),
4213 role: cm.role.into(),
4214 })
4215 .collect(),
4216 ..Default::default()
4217 };
4218
4219 let mut update = build_channels_update(result.new_channels);
4220 update.delete_channels = result
4221 .removed_channels
4222 .into_iter()
4223 .map(|id| id.to_proto())
4224 .collect();
4225 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4226
4227 for connection_id in connection_pool.user_connection_ids(user_id) {
4228 peer.send(connection_id, user_channels_update.clone())
4229 .trace_err();
4230 peer.send(connection_id, update.clone()).trace_err();
4231 }
4232}
4233
4234fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4235 proto::UpdateUserChannels {
4236 channel_memberships: channels
4237 .channel_memberships
4238 .iter()
4239 .map(|m| proto::ChannelMembership {
4240 channel_id: m.channel_id.to_proto(),
4241 role: m.role.into(),
4242 })
4243 .collect(),
4244 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4245 observed_channel_message_id: channels.observed_channel_messages.clone(),
4246 }
4247}
4248
4249fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4250 let mut update = proto::UpdateChannels::default();
4251
4252 for channel in channels.channels {
4253 update.channels.push(channel.to_proto());
4254 }
4255
4256 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4257 update.latest_channel_message_ids = channels.latest_channel_messages;
4258
4259 for (channel_id, participants) in channels.channel_participants {
4260 update
4261 .channel_participants
4262 .push(proto::ChannelParticipants {
4263 channel_id: channel_id.to_proto(),
4264 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4265 });
4266 }
4267
4268 for channel in channels.invited_channels {
4269 update.channel_invitations.push(channel.to_proto());
4270 }
4271
4272 update
4273}
4274
4275fn build_initial_contacts_update(
4276 contacts: Vec<db::Contact>,
4277 pool: &ConnectionPool,
4278) -> proto::UpdateContacts {
4279 let mut update = proto::UpdateContacts::default();
4280
4281 for contact in contacts {
4282 match contact {
4283 db::Contact::Accepted { user_id, busy } => {
4284 update.contacts.push(contact_for_user(user_id, busy, pool));
4285 }
4286 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4287 db::Contact::Incoming { user_id } => {
4288 update
4289 .incoming_requests
4290 .push(proto::IncomingContactRequest {
4291 requester_id: user_id.to_proto(),
4292 })
4293 }
4294 }
4295 }
4296
4297 update
4298}
4299
4300fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4301 proto::Contact {
4302 user_id: user_id.to_proto(),
4303 online: pool.is_user_online(user_id),
4304 busy,
4305 }
4306}
4307
4308fn room_updated(room: &proto::Room, peer: &Peer) {
4309 broadcast(
4310 None,
4311 room.participants
4312 .iter()
4313 .filter_map(|participant| Some(participant.peer_id?.into())),
4314 |peer_id| {
4315 peer.send(
4316 peer_id,
4317 proto::RoomUpdated {
4318 room: Some(room.clone()),
4319 },
4320 )
4321 },
4322 );
4323}
4324
4325fn channel_updated(
4326 channel: &db::channel::Model,
4327 room: &proto::Room,
4328 peer: &Peer,
4329 pool: &ConnectionPool,
4330) {
4331 let participants = room
4332 .participants
4333 .iter()
4334 .map(|p| p.user_id)
4335 .collect::<Vec<_>>();
4336
4337 broadcast(
4338 None,
4339 pool.channel_connection_ids(channel.root_id())
4340 .filter_map(|(channel_id, role)| {
4341 role.can_see_channel(channel.visibility)
4342 .then_some(channel_id)
4343 }),
4344 |peer_id| {
4345 peer.send(
4346 peer_id,
4347 proto::UpdateChannels {
4348 channel_participants: vec![proto::ChannelParticipants {
4349 channel_id: channel.id.to_proto(),
4350 participant_user_ids: participants.clone(),
4351 }],
4352 ..Default::default()
4353 },
4354 )
4355 },
4356 );
4357}
4358
4359async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4360 let db = session.db().await;
4361
4362 let contacts = db.get_contacts(user_id).await?;
4363 let busy = db.is_user_busy(user_id).await?;
4364
4365 let pool = session.connection_pool().await;
4366 let updated_contact = contact_for_user(user_id, busy, &pool);
4367 for contact in contacts {
4368 if let db::Contact::Accepted {
4369 user_id: contact_user_id,
4370 ..
4371 } = contact
4372 {
4373 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4374 session
4375 .peer
4376 .send(
4377 contact_conn_id,
4378 proto::UpdateContacts {
4379 contacts: vec![updated_contact.clone()],
4380 remove_contacts: Default::default(),
4381 incoming_requests: Default::default(),
4382 remove_incoming_requests: Default::default(),
4383 outgoing_requests: Default::default(),
4384 remove_outgoing_requests: Default::default(),
4385 },
4386 )
4387 .trace_err();
4388 }
4389 }
4390 }
4391 Ok(())
4392}
4393
/// Removes `connection_id` from the room it is in and notifies everyone
/// affected.
///
/// If `leave_room` reports that there was nothing to leave (`None`), this
/// returns `Ok(())` without sending anything. Otherwise it, in order:
///
/// 1. notifies project collaborators via `project_left` for each left project,
/// 2. broadcasts the updated room via `room_updated`,
/// 3. broadcasts the channel's participants via `channel_updated`, when the
///    room had a channel,
/// 4. sends `CallCanceled` to every connection of each user whose call was
///    canceled,
/// 5. refreshes contacts for the session user and all canceled users,
/// 6. when a LiveKit client is configured, removes the session user from the
///    LiveKit room and deletes that room if `left_room.deleted` was set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// Failures of peer sends and LiveKit calls are handed to `trace_err` and are
/// not propagated.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // outlive `left_room`, which goes out of scope at the end of that block.
    // NOTE(review): presumably this is deliberate so that `left_room` (and the
    // db handle temporary) are released before the later awaits — confirm
    // before restructuring.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `mem::take` leaves `left_room.room.livekit_room` at its default
        // value, so the `room` taken two lines below (and broadcast by
        // `room_updated`) no longer carries the LiveKit room name.
        // NOTE(review): confirm that receivers of `RoomUpdated` do not rely on it.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // Nothing was left, so there is nobody to notify.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped block: `pool` is dropped before `update_user_contacts` below,
    // which obtains the connection pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // The first failure aborts the remaining contact updates and skips the
    // LiveKit cleanup below, because the error is propagated with `?`.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4467
4468async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4469 let left_channel_buffers = session
4470 .db()
4471 .await
4472 .leave_channel_buffers(session.connection_id)
4473 .await?;
4474
4475 for left_buffer in left_channel_buffers {
4476 channel_buffer_updated(
4477 session.connection_id,
4478 left_buffer.connections,
4479 &proto::UpdateChannelBufferCollaborators {
4480 channel_id: left_buffer.channel_id.to_proto(),
4481 collaborators: left_buffer.collaborators,
4482 },
4483 &session.peer,
4484 );
4485 }
4486
4487 Ok(())
4488}
4489
4490fn project_left(project: &db::LeftProject, session: &Session) {
4491 for connection_id in &project.connection_ids {
4492 if project.should_unshare {
4493 session
4494 .peer
4495 .send(
4496 *connection_id,
4497 proto::UnshareProject {
4498 project_id: project.id.to_proto(),
4499 },
4500 )
4501 .trace_err();
4502 } else {
4503 session
4504 .peer
4505 .send(
4506 *connection_id,
4507 proto::RemoveProjectCollaborator {
4508 project_id: project.id.to_proto(),
4509 peer_id: Some(session.connection_id.into()),
4510 },
4511 )
4512 .trace_err();
4513 }
4514 }
4515}
4516
/// Extension trait for turning a fallible value into an `Option`, for call
/// sites that want to continue after a failure instead of propagating it.
pub trait ResultExt {
    /// The success type yielded by [`ResultExt::trace_err`].
    type Ok;

    /// Consumes `self`, returning the success value as `Some` and `None`
    /// otherwise. The implementation for `Result` in this file logs the error
    /// before discarding it.
    fn trace_err(self) -> Option<Self::Ok>;
}
4522
4523impl<T, E> ResultExt for Result<T, E>
4524where
4525 E: std::fmt::Debug,
4526{
4527 type Ok = T;
4528
4529 #[track_caller]
4530 fn trace_err(self) -> Option<T> {
4531 match self {
4532 Ok(value) => Some(value),
4533 Err(error) => {
4534 tracing::error!("{:?}", error);
4535 None
4536 }
4537 }
4538 }
4539}