1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use semantic_version::SemanticVersion;
50use serde::{Serialize, Serializer};
51use std::{
52 any::TypeId,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc, OnceLock,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{
69 field::{self},
70 info_span, instrument, Instrument,
71};
72use util::http::IsahcHttpClient;
73
/// Reconnection grace period. By its name, presumably how long a disconnected
/// client may take to reconnect before its state is released (its uses are
/// outside this excerpt — confirm).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
/// Delay `Server::start` waits before cleaning up rooms and channel buffers
/// left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (by name; used outside this excerpt).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Length limit for channel messages (by name; presumably bytes or chars — confirm at the use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (by name; used outside this excerpt).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by message `TypeId`;
/// called with the incoming envelope and the connection's `Session`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
86struct Response<R> {
87 peer: Arc<Peer>,
88 receipt: Receipt<R>,
89 responded: Arc<AtomicBool>,
90}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
100struct StreamingResponse<R: RequestMessage> {
101 peer: Arc<Peer>,
102 receipt: Receipt<R>,
103}
104
105impl<R: RequestMessage> StreamingResponse<R> {
106 fn send(&self, payload: R::Response) -> Result<()> {
107 self.peer.respond(self.receipt, payload)?;
108 Ok(())
109 }
110}
111
112#[derive(Clone, Debug)]
113pub enum Principal {
114 User(User),
115 Impersonated { user: User, admin: User },
116 DevServer(dev_server::Model),
117}
118
119impl Principal {
120 fn update_span(&self, span: &tracing::Span) {
121 match &self {
122 Principal::User(user) => {
123 span.record("user_id", &user.id.0);
124 span.record("login", &user.github_login);
125 }
126 Principal::Impersonated { user, admin } => {
127 span.record("user_id", &user.id.0);
128 span.record("login", &user.github_login);
129 span.record("impersonator", &admin.github_login);
130 }
131 Principal::DevServer(dev_server) => {
132 span.record("dev_server_id", &dev_server.id.0);
133 }
134 }
135 }
136}
137
138#[derive(Clone)]
139struct Session {
140 principal: Principal,
141 connection_id: ConnectionId,
142 db: Arc<tokio::sync::Mutex<DbHandle>>,
143 peer: Arc<Peer>,
144 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
146 http_client: IsahcHttpClient,
147 rate_limiter: Arc<RateLimiter>,
148 _executor: Executor,
149}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn user_id(&self) -> Option<UserId> {
176 match &self.principal {
177 Principal::User(user) => Some(user.id),
178 Principal::Impersonated { user, .. } => Some(user.id),
179 Principal::DevServer(_) => None,
180 }
181 }
182}
183
184impl Debug for Session {
185 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
186 let mut result = f.debug_struct("Session");
187 match &self.principal {
188 Principal::User(user) => {
189 result.field("user", &user.github_login);
190 }
191 Principal::Impersonated { user, admin } => {
192 result.field("user", &user.github_login);
193 result.field("impersonator", &admin.github_login);
194 }
195 Principal::DevServer(dev_server) => {
196 result.field("dev_server", &dev_server.id);
197 }
198 }
199 result.field("connection_id", &self.connection_id).finish()
200 }
201}
202
203struct UserSession(Session);
204
205impl UserSession {
206 pub fn new(s: Session) -> Option<Self> {
207 s.user_id().map(|_| UserSession(s))
208 }
209 pub fn user_id(&self) -> UserId {
210 self.0.user_id().unwrap()
211 }
212}
213
214impl Deref for UserSession {
215 type Target = Session;
216
217 fn deref(&self) -> &Self::Target {
218 &self.0
219 }
220}
221impl DerefMut for UserSession {
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227fn user_handler<M: RequestMessage, Fut>(
228 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
230where
231 Fut: Send + Future<Output = Result<()>>,
232{
233 let handler = Arc::new(handler);
234 move |message, response, session| {
235 let handler = handler.clone();
236 Box::pin(async move {
237 if let Some(user_session) = session.for_user() {
238 Ok(handler(message, response, user_session).await?)
239 } else {
240 Err(Error::Internal(anyhow!("must be a user")))
241 }
242 })
243 }
244}
245
246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
247 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
249where
250 InnertRetFut: Send + Future<Output = Result<()>>,
251{
252 let handler = Arc::new(handler);
253 move |message, session| {
254 let handler = handler.clone();
255 Box::pin(async move {
256 if let Some(user_session) = session.for_user() {
257 Ok(handler(message, user_session).await?)
258 } else {
259 Err(Error::Internal(anyhow!("must be a user")))
260 }
261 })
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
/// The RPC server: owns the peer, the pool of live connections and the table
/// of registered message handlers.
pub struct Server {
    /// This server's id; behind a mutex so `reset` (tests only) can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection loop subscribes and stops when it changes.
    teardown: watch::Sender<bool>,
}
283
/// Lock guard over the shared [`ConnectionPool`], produced by
/// `Session::connection_pool` and `Server::snapshot`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`; presumably to keep
    // it from being held across `.await` in futures that must be `Send` — confirm.
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable view of the server's peer and connection pool. The pool stays
/// locked for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself rather than the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
/// Creates a server with the given `id` and registers every RPC message handler.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions
/// that have no user (e.g. dev-server sessions). The language-model handlers
/// capture a clone of `app_state` and pass the configured API keys on each call.
///
/// # Panics
/// Panics if two handlers are registered for the same message type (see `add_handler`).
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; connections obtain receivers via `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                complete_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                    app_state.config.google_ai_api_key.clone(),
                    app_state.config.anthropic_api_key.clone(),
                )
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                count_tokens_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.google_ai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
440
/// Schedules cleanup of resources left behind by stale servers in this
/// environment, then returns `Ok(())` immediately.
///
/// A detached task waits for `CLEANUP_TIMEOUT`, then:
/// * for each stale channel buffer, calls `clear_stale_channel_buffer_collaborators`
///   and sends the refreshed collaborator list to the remaining connections;
/// * for each stale room, calls `clear_stale_room_participants`, broadcasts the
///   room/channel update, notifies users whose calls were canceled, refreshes the
///   affected users' contact entries, and deletes the LiveKit room if no
///   participants remain;
/// * finally calls `delete_stale_servers`.
///
/// Failures inside the task are logged via `trace_err` and do not abort it.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Filled in from the refreshed room (if the DB call succeeds)
                    // and consumed by the notification steps below.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Tell every connection of each callee that their call was canceled.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push each affected user's updated contact entry (busy state etc.)
                    // to all connections of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Only delete the LiveKit room once no participants remain.
                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
586
587 pub fn teardown(&self) {
588 self.peer.teardown();
589 self.connection_pool.lock().reset();
590 let _ = self.teardown.send(true);
591 }
592
/// Test-only: tears the server down, then re-arms it under the new `id`.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
    // Clear the teardown flag so new connections are accepted again.
    let _ = self.teardown.send(false);
}

/// Test-only: the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current_id = *self.id.lock();
    current_id
}
605
/// Registers `handler` as the type-erased handler for message type `M`.
///
/// The stored wrapper downcasts the envelope to `TypedEnvelope<M>`, logs receipt,
/// runs the handler, and logs the outcome with total (since `received_at`),
/// processing (since the handler was invoked) and queue (the difference) durations
/// in milliseconds.
///
/// # Panics
/// Panics if a handler for `M` was already registered.
fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
    F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
    Fut: 'static + Send + Future<Output = Result<()>>,
    M: EnvelopedMessage,
{
    let prev_handler = self.handlers.insert(
        TypeId::of::<M>(),
        Box::new(move |envelope, session| {
            // The dispatcher looks handlers up by the payload's `TypeId` (the key
            // used above), so this downcast is expected to succeed.
            let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
            let received_at = envelope.received_at;
            tracing::info!(
                "message received"
            );
            let start_time = Instant::now();
            let future = (handler)(*envelope, session);
            async move {
                let result = future.await;
                let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                // Time spent waiting between receipt and the handler being invoked.
                let queue_duration_ms = total_duration_ms - processing_duration_ms;
                match result {
                    Err(error) => {
                        tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
                    }
                    Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
                }
            }
            .boxed()
        }),
    );
    if prev_handler.is_some() {
        panic!("registered a handler for the same message twice");
    }
    self
}
642
643 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
644 where
645 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
646 Fut: 'static + Send + Future<Output = Result<()>>,
647 M: EnvelopedMessage,
648 {
649 self.add_handler(move |envelope, session| handler(envelope.payload, session));
650 self
651 }
652
653 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
654 where
655 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
656 Fut: Send + Future<Output = Result<()>>,
657 M: RequestMessage,
658 {
659 let handler = Arc::new(handler);
660 self.add_handler(move |envelope, session| {
661 let receipt = envelope.receipt();
662 let handler = handler.clone();
663 async move {
664 let peer = session.peer.clone();
665 let responded = Arc::new(AtomicBool::default());
666 let response = Response {
667 peer: peer.clone(),
668 responded: responded.clone(),
669 receipt,
670 };
671 match (handler)(envelope.payload, response, session).await {
672 Ok(()) => {
673 if responded.load(std::sync::atomic::Ordering::SeqCst) {
674 Ok(())
675 } else {
676 Err(anyhow!("handler did not send a response"))?
677 }
678 }
679 Err(error) => {
680 let proto_err = match &error {
681 Error::Internal(err) => err.to_proto(),
682 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
683 };
684 peer.respond_with_error(receipt, proto_err)?;
685 Err(error)
686 }
687 }
688 }
689 })
690 }
691
692 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
693 where
694 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
695 Fut: Send + Future<Output = Result<()>>,
696 M: RequestMessage,
697 {
698 let handler = Arc::new(handler);
699 self.add_handler(move |envelope, session| {
700 let receipt = envelope.receipt();
701 let handler = handler.clone();
702 async move {
703 let peer = session.peer.clone();
704 let response = StreamingResponse {
705 peer: peer.clone(),
706 receipt,
707 };
708 match (handler)(envelope.payload, response, session).await {
709 Ok(()) => {
710 peer.end_stream(receipt)?;
711 Ok(())
712 }
713 Err(error) => {
714 let proto_err = match &error {
715 Error::Internal(err) => err.to_proto(),
716 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
717 };
718 peer.respond_with_error(receipt, proto_err)?;
719 Err(error)
720 }
721 }
722 }
723 })
724 }
725
/// Returns a future that serves one client connection until the connection
/// closes, I/O fails, or the server tears down.
///
/// The future registers the connection with the peer, builds the `Session`,
/// sends the initial client update, then dispatches incoming messages:
/// * background messages are spawned detached on `executor`;
/// * foreground messages are polled concurrently inside this future;
/// * at most 256 handlers are in flight at once (a semaphore permit is held
///   from before reading a message until its handler finishes).
///
/// After the loop exits on connection close or I/O end, `connection_lost` is
/// awaited. On teardown the future returns without calling it.
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        dev_server_id=field::Empty
    );
    principal.update_span(&span);

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));
        tracing::info!("connection opened");

        // NOTE(review): this early return and the one after the initial client
        // update happen after `add_connection` but skip `connection_lost`;
        // confirm the peer connection is cleaned up elsewhere.
        let http_client = match IsahcHttpClient::new() {
            Ok(http_client) => http_client,
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            live_kit_client: this.app_state.live_kit_client.clone(),
            http_client,
            rate_limiter: this.app_state.rate_limiter.clone(),
            _executor: executor.clone(),
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            // A permit is acquired before reading, so no new message is read while
            // 256 handlers are still running.
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // Biased: teardown, then I/O completion, then progress on running
            // foreground handlers, then reading the next message.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                            dev_server_id=field::Empty
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            // Release the concurrency permit only once the handler finishes.
                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set cancels any foreground handlers still pending.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
862
863 async fn send_initial_client_update(
864 &self,
865 connection_id: ConnectionId,
866 principal: &Principal,
867 zed_version: ZedVersion,
868 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
869 session: &Session,
870 ) -> Result<()> {
871 self.peer.send(
872 connection_id,
873 proto::Hello {
874 peer_id: Some(connection_id.into()),
875 },
876 )?;
877 tracing::info!("sent hello message");
878
879 let Principal::User(user) = principal else {
880 return Ok(());
881 };
882
883 if let Some(send_connection_id) = send_connection_id.take() {
884 let _ = send_connection_id.send(connection_id);
885 }
886
887 if !user.connected_once {
888 self.peer.send(connection_id, proto::ShowContacts {})?;
889 self.app_state
890 .db
891 .set_user_connected_once(user.id, true)
892 .await?;
893 }
894
895 let (contacts, channels_for_user, channel_invites) = future::try_join3(
896 self.app_state.db.get_contacts(user.id),
897 self.app_state.db.get_channels_for_user(user.id),
898 self.app_state.db.get_channel_invites_for_user(user.id),
899 )
900 .await?;
901
902 {
903 let mut pool = self.connection_pool.lock();
904 pool.add_connection(connection_id, user.id, user.admin, zed_version);
905 for membership in &channels_for_user.channel_memberships {
906 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
907 }
908 self.peer.send(
909 connection_id,
910 build_initial_contacts_update(contacts, &pool),
911 )?;
912 self.peer.send(
913 connection_id,
914 build_update_user_channels(&channels_for_user),
915 )?;
916 self.peer.send(
917 connection_id,
918 build_channels_update(channels_for_user, channel_invites),
919 )?;
920 }
921
922 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
923 self.peer.send(connection_id, incoming_call)?;
924 }
925
926 update_user_contacts(user.id, &session).await?;
927 Ok(())
928 }
929
930 pub async fn invite_code_redeemed(
931 self: &Arc<Self>,
932 inviter_id: UserId,
933 invitee_id: UserId,
934 ) -> Result<()> {
935 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
936 if let Some(code) = &user.invite_code {
937 let pool = self.connection_pool.lock();
938 let invitee_contact = contact_for_user(invitee_id, false, &pool);
939 for connection_id in pool.user_connection_ids(inviter_id) {
940 self.peer.send(
941 connection_id,
942 proto::UpdateContacts {
943 contacts: vec![invitee_contact.clone()],
944 ..Default::default()
945 },
946 )?;
947 self.peer.send(
948 connection_id,
949 proto::UpdateInviteInfo {
950 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
951 count: user.invite_count as u32,
952 },
953 )?;
954 }
955 }
956 }
957 Ok(())
958 }
959
960 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
961 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
962 if let Some(invite_code) = &user.invite_code {
963 let pool = self.connection_pool.lock();
964 for connection_id in pool.user_connection_ids(user_id) {
965 self.peer.send(
966 connection_id,
967 proto::UpdateInviteInfo {
968 url: format!(
969 "{}{}",
970 self.app_state.config.invite_link_prefix, invite_code
971 ),
972 count: user.invite_count as u32,
973 },
974 )?;
975 }
976 }
977 }
978 Ok(())
979 }
980
981 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
982 ServerSnapshot {
983 connection_pool: ConnectionPoolGuard {
984 guard: self.connection_pool.lock(),
985 _not_send: PhantomData,
986 },
987 peer: &self.peer,
988 }
989 }
990}
991
992impl<'a> Deref for ConnectionPoolGuard<'a> {
993 type Target = ConnectionPool;
994
995 fn deref(&self) -> &Self::Target {
996 &self.guard
997 }
998}
999
1000impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1001 fn deref_mut(&mut self) -> &mut Self::Target {
1002 &mut self.guard
1003 }
1004}
1005
1006impl<'a> Drop for ConnectionPoolGuard<'a> {
1007 fn drop(&mut self) {
1008 #[cfg(test)]
1009 self.check_invariants();
1010 }
1011}
1012
1013fn broadcast<F>(
1014 sender_id: Option<ConnectionId>,
1015 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1016 mut f: F,
1017) where
1018 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1019{
1020 for receiver_id in receiver_ids {
1021 if Some(receiver_id) != sender_id {
1022 if let Err(error) = f(receiver_id) {
1023 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1024 }
1025 }
1026 }
1027}
1028
1029pub struct ProtocolVersion(u32);
1030
1031impl Header for ProtocolVersion {
1032 fn name() -> &'static HeaderName {
1033 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1035 }
1036
1037 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038 where
1039 Self: Sized,
1040 I: Iterator<Item = &'i axum::http::HeaderValue>,
1041 {
1042 let version = values
1043 .next()
1044 .ok_or_else(axum::headers::Error::invalid)?
1045 .to_str()
1046 .map_err(|_| axum::headers::Error::invalid())?
1047 .parse()
1048 .map_err(|_| axum::headers::Error::invalid())?;
1049 Ok(Self(version))
1050 }
1051
1052 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053 values.extend([self.0.to_string().parse().unwrap()]);
1054 }
1055}
1056
1057pub struct AppVersionHeader(SemanticVersion);
1058impl Header for AppVersionHeader {
1059 fn name() -> &'static HeaderName {
1060 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1061 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1062 }
1063
1064 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065 where
1066 Self: Sized,
1067 I: Iterator<Item = &'i axum::http::HeaderValue>,
1068 {
1069 let version = values
1070 .next()
1071 .ok_or_else(axum::headers::Error::invalid)?
1072 .to_str()
1073 .map_err(|_| axum::headers::Error::invalid())?
1074 .parse()
1075 .map_err(|_| axum::headers::Error::invalid())?;
1076 Ok(Self(version))
1077 }
1078
1079 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080 values.extend([self.0.to_string().parse().unwrap()]);
1081 }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085 Router::new()
1086 .route("/rpc", get(handle_websocket_request))
1087 .layer(
1088 ServiceBuilder::new()
1089 .layer(Extension(server.app_state.clone()))
1090 .layer(middleware::from_fn(auth::validate_header)),
1091 )
1092 .route("/metrics", get(handle_metrics))
1093 .layer(Extension(server))
1094}
1095
1096pub async fn handle_websocket_request(
1097 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1098 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1099 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1100 Extension(server): Extension<Arc<Server>>,
1101 Extension(principal): Extension<Principal>,
1102 ws: WebSocketUpgrade,
1103) -> axum::response::Response {
1104 if protocol_version != rpc::PROTOCOL_VERSION {
1105 return (
1106 StatusCode::UPGRADE_REQUIRED,
1107 "client must be upgraded".to_string(),
1108 )
1109 .into_response();
1110 }
1111
1112 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1113 return (
1114 StatusCode::UPGRADE_REQUIRED,
1115 "no version header found".to_string(),
1116 )
1117 .into_response();
1118 };
1119
1120 if !version.can_collaborate() {
1121 return (
1122 StatusCode::UPGRADE_REQUIRED,
1123 "client must be upgraded".to_string(),
1124 )
1125 .into_response();
1126 }
1127
1128 let socket_address = socket_address.to_string();
1129 ws.on_upgrade(move |socket| {
1130 let socket = socket
1131 .map_ok(to_tungstenite_message)
1132 .err_into()
1133 .with(|message| async move { Ok(to_axum_message(message)) });
1134 let connection = Connection::new(Box::pin(socket));
1135 async move {
1136 server
1137 .handle_connection(
1138 connection,
1139 socket_address,
1140 principal,
1141 version,
1142 None,
1143 Executor::Production,
1144 )
1145 .await;
1146 }
1147 })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152 let connections_metric = CONNECTIONS_METRIC
1153 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155 let connections = server
1156 .connection_pool
1157 .lock()
1158 .connections()
1159 .filter(|connection| !connection.admin)
1160 .count();
1161 connections_metric.set(connections as _);
1162
1163 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165 register_int_gauge!(
1166 "shared_projects",
1167 "number of open projects with one or more guests"
1168 )
1169 .unwrap()
1170 });
1171
1172 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173 shared_projects_metric.set(shared_projects as _);
1174
1175 let encoder = prometheus::TextEncoder::new();
1176 let metric_families = prometheus::gather();
1177 let encoded_metrics = encoder
1178 .encode_to_string(&metric_families)
1179 .map_err(|err| anyhow!("{}", err))?;
1180 Ok(encoded_metrics)
1181}
1182
/// Cleans up after a client connection has dropped.
///
/// Steps that run immediately:
/// - disconnect the peer;
/// - remove the connection from the pool (an error here is returned);
/// - call `Database::connection_lost`, whose failure is only passed to
///   `trace_err` and does not abort the cleanup.
///
/// It then waits for whichever comes first.
///
/// If `RECONNECT_TIMEOUT` elapses (NOTE(review): presumably this is the grace
/// period for the client to reconnect), and this session belongs to a user:
/// - the room and channel-buffer memberships for this connection are released;
/// - if the user has no other online connection, any incoming call is declined
///   (`decline_call(None, ..)`);
/// - the user's contacts are updated.
///
/// If `teardown` changes first, that deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Only sessions belonging to a user hold room/channel resources.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session, session.connection_id).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only decline a pending call when no other connection of this
                // user is still online to answer it.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1230
1231/// Acknowledges a ping from a client, used to keep the connection alive.
1232async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1233 response.send(proto::Ack {})?;
1234 Ok(())
1235}
1236
1237/// Creates a new room for calling (outside of channels)
1238async fn create_room(
1239 _request: proto::CreateRoom,
1240 response: Response<proto::CreateRoom>,
1241 session: UserSession,
1242) -> Result<()> {
1243 let live_kit_room = nanoid::nanoid!(30);
1244
1245 let live_kit_connection_info = util::maybe!(async {
1246 let live_kit = session.live_kit_client.as_ref();
1247 let live_kit = live_kit?;
1248 let user_id = session.user_id().to_string();
1249
1250 let token = live_kit
1251 .room_token(&live_kit_room, &user_id.to_string())
1252 .trace_err()?;
1253
1254 Some(proto::LiveKitConnectionInfo {
1255 server_url: live_kit.url().into(),
1256 token,
1257 can_publish: true,
1258 })
1259 })
1260 .await;
1261
1262 let room = session
1263 .db()
1264 .await
1265 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1266 .await?;
1267
1268 response.send(proto::CreateRoomResponse {
1269 room: Some(room.clone()),
1270 live_kit_connection_info,
1271 })?;
1272
1273 update_user_contacts(session.user_id(), &session).await?;
1274 Ok(())
1275}
1276
1277/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1278async fn join_room(
1279 request: proto::JoinRoom,
1280 response: Response<proto::JoinRoom>,
1281 session: UserSession,
1282) -> Result<()> {
1283 let room_id = RoomId::from_proto(request.id);
1284
1285 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1286
1287 if let Some(channel_id) = channel_id {
1288 return join_channel_internal(channel_id, Box::new(response), session).await;
1289 }
1290
1291 let joined_room = {
1292 let room = session
1293 .db()
1294 .await
1295 .join_room(room_id, session.user_id(), session.connection_id)
1296 .await?;
1297 room_updated(&room.room, &session.peer);
1298 room.into_inner()
1299 };
1300
1301 for connection_id in session
1302 .connection_pool()
1303 .await
1304 .user_connection_ids(session.user_id())
1305 {
1306 session
1307 .peer
1308 .send(
1309 connection_id,
1310 proto::CallCanceled {
1311 room_id: room_id.to_proto(),
1312 },
1313 )
1314 .trace_err();
1315 }
1316
1317 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1318 if let Some(token) = live_kit
1319 .room_token(
1320 &joined_room.room.live_kit_room,
1321 &session.user_id().to_string(),
1322 )
1323 .trace_err()
1324 {
1325 Some(proto::LiveKitConnectionInfo {
1326 server_url: live_kit.url().into(),
1327 token,
1328 can_publish: true,
1329 })
1330 } else {
1331 None
1332 }
1333 } else {
1334 None
1335 };
1336
1337 response.send(proto::JoinRoomResponse {
1338 room: Some(joined_room.room),
1339 channel_id: None,
1340 live_kit_connection_info,
1341 })?;
1342
1343 update_user_contacts(session.user_id(), &session).await?;
1344 Ok(())
1345}
1346
1347/// Rejoin room is used to reconnect to a room after connection errors.
1348async fn rejoin_room(
1349 request: proto::RejoinRoom,
1350 response: Response<proto::RejoinRoom>,
1351 session: UserSession,
1352) -> Result<()> {
1353 let room;
1354 let channel;
1355 {
1356 let mut rejoined_room = session
1357 .db()
1358 .await
1359 .rejoin_room(request, session.user_id(), session.connection_id)
1360 .await?;
1361
1362 response.send(proto::RejoinRoomResponse {
1363 room: Some(rejoined_room.room.clone()),
1364 reshared_projects: rejoined_room
1365 .reshared_projects
1366 .iter()
1367 .map(|project| proto::ResharedProject {
1368 id: project.id.to_proto(),
1369 collaborators: project
1370 .collaborators
1371 .iter()
1372 .map(|collaborator| collaborator.to_proto())
1373 .collect(),
1374 })
1375 .collect(),
1376 rejoined_projects: rejoined_room
1377 .rejoined_projects
1378 .iter()
1379 .map(|rejoined_project| proto::RejoinedProject {
1380 id: rejoined_project.id.to_proto(),
1381 worktrees: rejoined_project
1382 .worktrees
1383 .iter()
1384 .map(|worktree| proto::WorktreeMetadata {
1385 id: worktree.id,
1386 root_name: worktree.root_name.clone(),
1387 visible: worktree.visible,
1388 abs_path: worktree.abs_path.clone(),
1389 })
1390 .collect(),
1391 collaborators: rejoined_project
1392 .collaborators
1393 .iter()
1394 .map(|collaborator| collaborator.to_proto())
1395 .collect(),
1396 language_servers: rejoined_project.language_servers.clone(),
1397 })
1398 .collect(),
1399 })?;
1400 room_updated(&rejoined_room.room, &session.peer);
1401
1402 for project in &rejoined_room.reshared_projects {
1403 for collaborator in &project.collaborators {
1404 session
1405 .peer
1406 .send(
1407 collaborator.connection_id,
1408 proto::UpdateProjectCollaborator {
1409 project_id: project.id.to_proto(),
1410 old_peer_id: Some(project.old_connection_id.into()),
1411 new_peer_id: Some(session.connection_id.into()),
1412 },
1413 )
1414 .trace_err();
1415 }
1416
1417 broadcast(
1418 Some(session.connection_id),
1419 project
1420 .collaborators
1421 .iter()
1422 .map(|collaborator| collaborator.connection_id),
1423 |connection_id| {
1424 session.peer.forward_send(
1425 session.connection_id,
1426 connection_id,
1427 proto::UpdateProject {
1428 project_id: project.id.to_proto(),
1429 worktrees: project.worktrees.clone(),
1430 },
1431 )
1432 },
1433 );
1434 }
1435
1436 for project in &rejoined_room.rejoined_projects {
1437 for collaborator in &project.collaborators {
1438 session
1439 .peer
1440 .send(
1441 collaborator.connection_id,
1442 proto::UpdateProjectCollaborator {
1443 project_id: project.id.to_proto(),
1444 old_peer_id: Some(project.old_connection_id.into()),
1445 new_peer_id: Some(session.connection_id.into()),
1446 },
1447 )
1448 .trace_err();
1449 }
1450 }
1451
1452 for project in &mut rejoined_room.rejoined_projects {
1453 for worktree in mem::take(&mut project.worktrees) {
1454 #[cfg(any(test, feature = "test-support"))]
1455 const MAX_CHUNK_SIZE: usize = 2;
1456 #[cfg(not(any(test, feature = "test-support")))]
1457 const MAX_CHUNK_SIZE: usize = 256;
1458
1459 // Stream this worktree's entries.
1460 let message = proto::UpdateWorktree {
1461 project_id: project.id.to_proto(),
1462 worktree_id: worktree.id,
1463 abs_path: worktree.abs_path.clone(),
1464 root_name: worktree.root_name,
1465 updated_entries: worktree.updated_entries,
1466 removed_entries: worktree.removed_entries,
1467 scan_id: worktree.scan_id,
1468 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1469 updated_repositories: worktree.updated_repositories,
1470 removed_repositories: worktree.removed_repositories,
1471 };
1472 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1473 session.peer.send(session.connection_id, update.clone())?;
1474 }
1475
1476 // Stream this worktree's diagnostics.
1477 for summary in worktree.diagnostic_summaries {
1478 session.peer.send(
1479 session.connection_id,
1480 proto::UpdateDiagnosticSummary {
1481 project_id: project.id.to_proto(),
1482 worktree_id: worktree.id,
1483 summary: Some(summary),
1484 },
1485 )?;
1486 }
1487
1488 for settings_file in worktree.settings_files {
1489 session.peer.send(
1490 session.connection_id,
1491 proto::UpdateWorktreeSettings {
1492 project_id: project.id.to_proto(),
1493 worktree_id: worktree.id,
1494 path: settings_file.path,
1495 content: Some(settings_file.content),
1496 },
1497 )?;
1498 }
1499 }
1500
1501 for language_server in &project.language_servers {
1502 session.peer.send(
1503 session.connection_id,
1504 proto::UpdateLanguageServer {
1505 project_id: project.id.to_proto(),
1506 language_server_id: language_server.id,
1507 variant: Some(
1508 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1509 proto::LspDiskBasedDiagnosticsUpdated {},
1510 ),
1511 ),
1512 },
1513 )?;
1514 }
1515 }
1516
1517 let rejoined_room = rejoined_room.into_inner();
1518
1519 room = rejoined_room.room;
1520 channel = rejoined_room.channel;
1521 }
1522
1523 if let Some(channel) = channel {
1524 channel_updated(
1525 &channel,
1526 &room,
1527 &session.peer,
1528 &*session.connection_pool().await,
1529 );
1530 }
1531
1532 update_user_contacts(session.user_id(), &session).await?;
1533 Ok(())
1534}
1535
1536/// leave room disconnects from the room.
1537async fn leave_room(
1538 _: proto::LeaveRoom,
1539 response: Response<proto::LeaveRoom>,
1540 session: UserSession,
1541) -> Result<()> {
1542 leave_room_for_session(&session, session.connection_id).await?;
1543 response.send(proto::Ack {})?;
1544 Ok(())
1545}
1546
1547/// Updates the permissions of someone else in the room.
1548async fn set_room_participant_role(
1549 request: proto::SetRoomParticipantRole,
1550 response: Response<proto::SetRoomParticipantRole>,
1551 session: UserSession,
1552) -> Result<()> {
1553 let user_id = UserId::from_proto(request.user_id);
1554 let role = ChannelRole::from(request.role());
1555
1556 let (live_kit_room, can_publish) = {
1557 let room = session
1558 .db()
1559 .await
1560 .set_room_participant_role(
1561 session.user_id(),
1562 RoomId::from_proto(request.room_id),
1563 user_id,
1564 role,
1565 )
1566 .await?;
1567
1568 let live_kit_room = room.live_kit_room.clone();
1569 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1570 room_updated(&room, &session.peer);
1571 (live_kit_room, can_publish)
1572 };
1573
1574 if let Some(live_kit) = session.live_kit_client.as_ref() {
1575 live_kit
1576 .update_participant(
1577 live_kit_room.clone(),
1578 request.user_id.to_string(),
1579 live_kit_server::proto::ParticipantPermission {
1580 can_subscribe: true,
1581 can_publish,
1582 can_publish_data: can_publish,
1583 hidden: false,
1584 recorder: false,
1585 },
1586 )
1587 .await
1588 .trace_err();
1589 }
1590
1591 response.send(proto::Ack {})?;
1592 Ok(())
1593}
1594
/// Calls another user into the current room.
///
/// The callee must be a contact of the caller. The call is recorded via
/// `Database::call`, `room_updated` is invoked, and `update_user_contacts` is
/// called for the callee.
///
/// The incoming-call message returned by the database is then sent as a
/// request to every connection of the callee, all at once. The first
/// connection that responds successfully completes this request with an `Ack`.
///
/// If none succeed:
/// - the call is marked failed via `Database::call_failed`;
/// - the room and the callee's contacts are updated again;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `Database::call` is dropped before the
    // callee's connections are contacted; `mem::take` moves the incoming-call
    // message out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Succeed as soon as any one connection responds successfully; failed
    // responses are logged and the remaining ones are still awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted: record the failure and refresh state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1663
1664/// Cancel an outgoing call.
1665async fn cancel_call(
1666 request: proto::CancelCall,
1667 response: Response<proto::CancelCall>,
1668 session: UserSession,
1669) -> Result<()> {
1670 let called_user_id = UserId::from_proto(request.called_user_id);
1671 let room_id = RoomId::from_proto(request.room_id);
1672 {
1673 let room = session
1674 .db()
1675 .await
1676 .cancel_call(room_id, session.connection_id, called_user_id)
1677 .await?;
1678 room_updated(&room, &session.peer);
1679 }
1680
1681 for connection_id in session
1682 .connection_pool()
1683 .await
1684 .user_connection_ids(called_user_id)
1685 {
1686 session
1687 .peer
1688 .send(
1689 connection_id,
1690 proto::CallCanceled {
1691 room_id: room_id.to_proto(),
1692 },
1693 )
1694 .trace_err();
1695 }
1696 response.send(proto::Ack {})?;
1697
1698 update_user_contacts(called_user_id, &session).await?;
1699 Ok(())
1700}
1701
1702/// Decline an incoming call.
1703async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1704 let room_id = RoomId::from_proto(message.room_id);
1705 {
1706 let room = session
1707 .db()
1708 .await
1709 .decline_call(Some(room_id), session.user_id())
1710 .await?
1711 .ok_or_else(|| anyhow!("failed to decline call"))?;
1712 room_updated(&room, &session.peer);
1713 }
1714
1715 for connection_id in session
1716 .connection_pool()
1717 .await
1718 .user_connection_ids(session.user_id())
1719 {
1720 session
1721 .peer
1722 .send(
1723 connection_id,
1724 proto::CallCanceled {
1725 room_id: room_id.to_proto(),
1726 },
1727 )
1728 .trace_err();
1729 }
1730 update_user_contacts(session.user_id(), &session).await?;
1731 Ok(())
1732}
1733
1734/// Updates other participants in the room with your current location.
1735async fn update_participant_location(
1736 request: proto::UpdateParticipantLocation,
1737 response: Response<proto::UpdateParticipantLocation>,
1738 session: UserSession,
1739) -> Result<()> {
1740 let room_id = RoomId::from_proto(request.room_id);
1741 let location = request
1742 .location
1743 .ok_or_else(|| anyhow!("invalid location"))?;
1744
1745 let db = session.db().await;
1746 let room = db
1747 .update_room_participant_location(room_id, session.connection_id, location)
1748 .await?;
1749
1750 room_updated(&room, &session.peer);
1751 response.send(proto::Ack {})?;
1752 Ok(())
1753}
1754
1755/// Share a project into the room.
1756async fn share_project(
1757 request: proto::ShareProject,
1758 response: Response<proto::ShareProject>,
1759 session: Session,
1760) -> Result<()> {
1761 let (project_id, room) = &*session
1762 .db()
1763 .await
1764 .share_project(
1765 RoomId::from_proto(request.room_id),
1766 session.connection_id,
1767 &request.worktrees,
1768 )
1769 .await?;
1770 response.send(proto::ShareProjectResponse {
1771 project_id: project_id.to_proto(),
1772 })?;
1773 room_updated(&room, &session.peer);
1774
1775 Ok(())
1776}
1777
1778/// Unshare a project from the room.
1779async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1780 let project_id = ProjectId::from_proto(message.project_id);
1781
1782 let (room, guest_connection_ids) = &*session
1783 .db()
1784 .await
1785 .unshare_project(project_id, session.connection_id)
1786 .await?;
1787
1788 broadcast(
1789 Some(session.connection_id),
1790 guest_connection_ids.iter().copied(),
1791 |conn_id| session.peer.send(conn_id, message.clone()),
1792 );
1793 room_updated(&room, &session.peer);
1794
1795 Ok(())
1796}
1797
1798/// Join someone elses shared project.
1799async fn join_project(
1800 request: proto::JoinProject,
1801 response: Response<proto::JoinProject>,
1802 session: UserSession,
1803) -> Result<()> {
1804 let project_id = ProjectId::from_proto(request.project_id);
1805
1806 tracing::info!(%project_id, "join project");
1807
1808 let (project, replica_id) = &mut *session
1809 .db()
1810 .await
1811 .join_project_in_room(project_id, session.connection_id)
1812 .await?;
1813
1814 join_project_internal(response, session, project, replica_id)
1815}
1816
/// A response handle that can carry a `proto::JoinProjectResponse`.
///
/// Implemented for the `JoinProject` and `JoinHostedProject` responses, so
/// `join_project_internal` can reply to either request type.
trait JoinProjectInternalResponse {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it is clear this calls the inherent `Response::send`,
        // not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so it is clear this calls the inherent `Response::send`,
        // not this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1830
1831fn join_project_internal(
1832 response: impl JoinProjectInternalResponse,
1833 session: UserSession,
1834 project: &mut Project,
1835 replica_id: &ReplicaId,
1836) -> Result<()> {
1837 let collaborators = project
1838 .collaborators
1839 .iter()
1840 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1841 .map(|collaborator| collaborator.to_proto())
1842 .collect::<Vec<_>>();
1843 let project_id = project.id;
1844 let guest_user_id = session.user_id();
1845
1846 let worktrees = project
1847 .worktrees
1848 .iter()
1849 .map(|(id, worktree)| proto::WorktreeMetadata {
1850 id: *id,
1851 root_name: worktree.root_name.clone(),
1852 visible: worktree.visible,
1853 abs_path: worktree.abs_path.clone(),
1854 })
1855 .collect::<Vec<_>>();
1856
1857 for collaborator in &collaborators {
1858 session
1859 .peer
1860 .send(
1861 collaborator.peer_id.unwrap().into(),
1862 proto::AddProjectCollaborator {
1863 project_id: project_id.to_proto(),
1864 collaborator: Some(proto::Collaborator {
1865 peer_id: Some(session.connection_id.into()),
1866 replica_id: replica_id.0 as u32,
1867 user_id: guest_user_id.to_proto(),
1868 }),
1869 },
1870 )
1871 .trace_err();
1872 }
1873
1874 // First, we send the metadata associated with each worktree.
1875 response.send(proto::JoinProjectResponse {
1876 project_id: project.id.0 as u64,
1877 worktrees: worktrees.clone(),
1878 replica_id: replica_id.0 as u32,
1879 collaborators: collaborators.clone(),
1880 language_servers: project.language_servers.clone(),
1881 role: project.role.into(), // todo
1882 })?;
1883
1884 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885 #[cfg(any(test, feature = "test-support"))]
1886 const MAX_CHUNK_SIZE: usize = 2;
1887 #[cfg(not(any(test, feature = "test-support")))]
1888 const MAX_CHUNK_SIZE: usize = 256;
1889
1890 // Stream this worktree's entries.
1891 let message = proto::UpdateWorktree {
1892 project_id: project_id.to_proto(),
1893 worktree_id,
1894 abs_path: worktree.abs_path.clone(),
1895 root_name: worktree.root_name,
1896 updated_entries: worktree.entries,
1897 removed_entries: Default::default(),
1898 scan_id: worktree.scan_id,
1899 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1900 updated_repositories: worktree.repository_entries.into_values().collect(),
1901 removed_repositories: Default::default(),
1902 };
1903 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1904 session.peer.send(session.connection_id, update.clone())?;
1905 }
1906
1907 // Stream this worktree's diagnostics.
1908 for summary in worktree.diagnostic_summaries {
1909 session.peer.send(
1910 session.connection_id,
1911 proto::UpdateDiagnosticSummary {
1912 project_id: project_id.to_proto(),
1913 worktree_id: worktree.id,
1914 summary: Some(summary),
1915 },
1916 )?;
1917 }
1918
1919 for settings_file in worktree.settings_files {
1920 session.peer.send(
1921 session.connection_id,
1922 proto::UpdateWorktreeSettings {
1923 project_id: project_id.to_proto(),
1924 worktree_id: worktree.id,
1925 path: settings_file.path,
1926 content: Some(settings_file.content),
1927 },
1928 )?;
1929 }
1930 }
1931
1932 for language_server in &project.language_servers {
1933 session.peer.send(
1934 session.connection_id,
1935 proto::UpdateLanguageServer {
1936 project_id: project_id.to_proto(),
1937 language_server_id: language_server.id,
1938 variant: Some(
1939 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1940 proto::LspDiskBasedDiagnosticsUpdated {},
1941 ),
1942 ),
1943 },
1944 )?;
1945 }
1946
1947 Ok(())
1948}
1949
1950/// Leave someone elses shared project.
1951async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1952 let sender_id = session.connection_id;
1953 let project_id = ProjectId::from_proto(request.project_id);
1954 let db = session.db().await;
1955 if db.is_hosted_project(project_id).await? {
1956 let project = db.leave_hosted_project(project_id, sender_id).await?;
1957 project_left(&project, &session);
1958 return Ok(());
1959 }
1960
1961 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1962 tracing::info!(
1963 %project_id,
1964 host_user_id = ?project.host_user_id,
1965 host_connection_id = ?project.host_connection_id,
1966 "leave project"
1967 );
1968
1969 project_left(&project, &session);
1970 room_updated(&room, &session.peer);
1971
1972 Ok(())
1973}
1974
1975async fn join_hosted_project(
1976 request: proto::JoinHostedProject,
1977 response: Response<proto::JoinHostedProject>,
1978 session: UserSession,
1979) -> Result<()> {
1980 let (mut project, replica_id) = session
1981 .db()
1982 .await
1983 .join_hosted_project(
1984 ProjectId(request.project_id as i32),
1985 session.user_id(),
1986 session.connection_id,
1987 )
1988 .await?;
1989
1990 join_project_internal(response, session, &mut project, &replica_id)
1991}
1992
1993/// Updates other participants with changes to the project
1994async fn update_project(
1995 request: proto::UpdateProject,
1996 response: Response<proto::UpdateProject>,
1997 session: Session,
1998) -> Result<()> {
1999 let project_id = ProjectId::from_proto(request.project_id);
2000 let (room, guest_connection_ids) = &*session
2001 .db()
2002 .await
2003 .update_project(project_id, session.connection_id, &request.worktrees)
2004 .await?;
2005 broadcast(
2006 Some(session.connection_id),
2007 guest_connection_ids.iter().copied(),
2008 |connection_id| {
2009 session
2010 .peer
2011 .forward_send(session.connection_id, connection_id, request.clone())
2012 },
2013 );
2014 room_updated(&room, &session.peer);
2015 response.send(proto::Ack {})?;
2016
2017 Ok(())
2018}
2019
2020/// Updates other participants with changes to the worktree
2021async fn update_worktree(
2022 request: proto::UpdateWorktree,
2023 response: Response<proto::UpdateWorktree>,
2024 session: Session,
2025) -> Result<()> {
2026 let guest_connection_ids = session
2027 .db()
2028 .await
2029 .update_worktree(&request, session.connection_id)
2030 .await?;
2031
2032 broadcast(
2033 Some(session.connection_id),
2034 guest_connection_ids.iter().copied(),
2035 |connection_id| {
2036 session
2037 .peer
2038 .forward_send(session.connection_id, connection_id, request.clone())
2039 },
2040 );
2041 response.send(proto::Ack {})?;
2042 Ok(())
2043}
2044
2045/// Updates other participants with changes to the diagnostics
2046async fn update_diagnostic_summary(
2047 message: proto::UpdateDiagnosticSummary,
2048 session: Session,
2049) -> Result<()> {
2050 let guest_connection_ids = session
2051 .db()
2052 .await
2053 .update_diagnostic_summary(&message, session.connection_id)
2054 .await?;
2055
2056 broadcast(
2057 Some(session.connection_id),
2058 guest_connection_ids.iter().copied(),
2059 |connection_id| {
2060 session
2061 .peer
2062 .forward_send(session.connection_id, connection_id, message.clone())
2063 },
2064 );
2065
2066 Ok(())
2067}
2068
2069/// Updates other participants with changes to the worktree settings
2070async fn update_worktree_settings(
2071 message: proto::UpdateWorktreeSettings,
2072 session: Session,
2073) -> Result<()> {
2074 let guest_connection_ids = session
2075 .db()
2076 .await
2077 .update_worktree_settings(&message, session.connection_id)
2078 .await?;
2079
2080 broadcast(
2081 Some(session.connection_id),
2082 guest_connection_ids.iter().copied(),
2083 |connection_id| {
2084 session
2085 .peer
2086 .forward_send(session.connection_id, connection_id, message.clone())
2087 },
2088 );
2089
2090 Ok(())
2091}
2092
2093/// Notify other participants that a language server has started.
2094async fn start_language_server(
2095 request: proto::StartLanguageServer,
2096 session: Session,
2097) -> Result<()> {
2098 let guest_connection_ids = session
2099 .db()
2100 .await
2101 .start_language_server(&request, session.connection_id)
2102 .await?;
2103
2104 broadcast(
2105 Some(session.connection_id),
2106 guest_connection_ids.iter().copied(),
2107 |connection_id| {
2108 session
2109 .peer
2110 .forward_send(session.connection_id, connection_id, request.clone())
2111 },
2112 );
2113 Ok(())
2114}
2115
2116/// Notify other participants that a language server has changed.
2117async fn update_language_server(
2118 request: proto::UpdateLanguageServer,
2119 session: Session,
2120) -> Result<()> {
2121 let project_id = ProjectId::from_proto(request.project_id);
2122 let project_connection_ids = session
2123 .db()
2124 .await
2125 .project_connection_ids(project_id, session.connection_id)
2126 .await?;
2127 broadcast(
2128 Some(session.connection_id),
2129 project_connection_ids.iter().copied(),
2130 |connection_id| {
2131 session
2132 .peer
2133 .forward_send(session.connection_id, connection_id, request.clone())
2134 },
2135 );
2136 Ok(())
2137}
2138
2139/// forward a project request to the host. These requests should be read only
2140/// as guests are allowed to send them.
2141async fn forward_read_only_project_request<T>(
2142 request: T,
2143 response: Response<T>,
2144 session: Session,
2145) -> Result<()>
2146where
2147 T: EntityMessage + RequestMessage,
2148{
2149 let project_id = ProjectId::from_proto(request.remote_entity_id());
2150 let host_connection_id = session
2151 .db()
2152 .await
2153 .host_for_read_only_project_request(project_id, session.connection_id)
2154 .await?;
2155 let payload = session
2156 .peer
2157 .forward_request(session.connection_id, host_connection_id, request)
2158 .await?;
2159 response.send(payload)?;
2160 Ok(())
2161}
2162
2163/// forward a project request to the host. These requests are disallowed
2164/// for guests.
2165async fn forward_mutating_project_request<T>(
2166 request: T,
2167 response: Response<T>,
2168 session: Session,
2169) -> Result<()>
2170where
2171 T: EntityMessage + RequestMessage,
2172{
2173 let project_id = ProjectId::from_proto(request.remote_entity_id());
2174 let host_connection_id = session
2175 .db()
2176 .await
2177 .host_for_mutating_project_request(project_id, session.connection_id)
2178 .await?;
2179 let payload = session
2180 .peer
2181 .forward_request(session.connection_id, host_connection_id, request)
2182 .await?;
2183 response.send(payload)?;
2184 Ok(())
2185}
2186
2187/// Notify other participants that a new buffer has been created
2188async fn create_buffer_for_peer(
2189 request: proto::CreateBufferForPeer,
2190 session: Session,
2191) -> Result<()> {
2192 session
2193 .db()
2194 .await
2195 .check_user_is_project_host(
2196 ProjectId::from_proto(request.project_id),
2197 session.connection_id,
2198 )
2199 .await?;
2200 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2201 session
2202 .peer
2203 .forward_send(session.connection_id, peer_id.into(), request)?;
2204 Ok(())
2205}
2206
2207/// Notify other participants that a buffer has been updated. This is
2208/// allowed for guests as long as the update is limited to selections.
2209async fn update_buffer(
2210 request: proto::UpdateBuffer,
2211 response: Response<proto::UpdateBuffer>,
2212 session: Session,
2213) -> Result<()> {
2214 let project_id = ProjectId::from_proto(request.project_id);
2215 let mut guest_connection_ids;
2216 let mut host_connection_id = None;
2217
2218 let mut requires_write_permission = false;
2219
2220 for op in request.operations.iter() {
2221 match op.variant {
2222 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2223 Some(_) => requires_write_permission = true,
2224 }
2225 }
2226
2227 {
2228 let collaborators = session
2229 .db()
2230 .await
2231 .project_collaborators_for_buffer_update(
2232 project_id,
2233 session.connection_id,
2234 requires_write_permission,
2235 )
2236 .await?;
2237 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2238 for collaborator in collaborators.iter() {
2239 if collaborator.is_host {
2240 host_connection_id = Some(collaborator.connection_id);
2241 } else {
2242 guest_connection_ids.push(collaborator.connection_id);
2243 }
2244 }
2245 }
2246 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2247
2248 broadcast(
2249 Some(session.connection_id),
2250 guest_connection_ids,
2251 |connection_id| {
2252 session
2253 .peer
2254 .forward_send(session.connection_id, connection_id, request.clone())
2255 },
2256 );
2257 if host_connection_id != session.connection_id {
2258 session
2259 .peer
2260 .forward_request(session.connection_id, host_connection_id, request.clone())
2261 .await?;
2262 }
2263
2264 response.send(proto::Ack {})?;
2265 Ok(())
2266}
2267
2268/// Notify other participants that a project has been updated.
2269async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2270 request: T,
2271 session: Session,
2272) -> Result<()> {
2273 let project_id = ProjectId::from_proto(request.remote_entity_id());
2274 let project_connection_ids = session
2275 .db()
2276 .await
2277 .project_connection_ids(project_id, session.connection_id)
2278 .await?;
2279
2280 broadcast(
2281 Some(session.connection_id),
2282 project_connection_ids.iter().copied(),
2283 |connection_id| {
2284 session
2285 .peer
2286 .forward_send(session.connection_id, connection_id, request.clone())
2287 },
2288 );
2289 Ok(())
2290}
2291
2292/// Start following another user in a call.
2293async fn follow(
2294 request: proto::Follow,
2295 response: Response<proto::Follow>,
2296 session: UserSession,
2297) -> Result<()> {
2298 let room_id = RoomId::from_proto(request.room_id);
2299 let project_id = request.project_id.map(ProjectId::from_proto);
2300 let leader_id = request
2301 .leader_id
2302 .ok_or_else(|| anyhow!("invalid leader id"))?
2303 .into();
2304 let follower_id = session.connection_id;
2305
2306 session
2307 .db()
2308 .await
2309 .check_room_participants(room_id, leader_id, session.connection_id)
2310 .await?;
2311
2312 let response_payload = session
2313 .peer
2314 .forward_request(session.connection_id, leader_id, request)
2315 .await?;
2316 response.send(response_payload)?;
2317
2318 if let Some(project_id) = project_id {
2319 let room = session
2320 .db()
2321 .await
2322 .follow(room_id, project_id, leader_id, follower_id)
2323 .await?;
2324 room_updated(&room, &session.peer);
2325 }
2326
2327 Ok(())
2328}
2329
2330/// Stop following another user in a call.
2331async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2332 let room_id = RoomId::from_proto(request.room_id);
2333 let project_id = request.project_id.map(ProjectId::from_proto);
2334 let leader_id = request
2335 .leader_id
2336 .ok_or_else(|| anyhow!("invalid leader id"))?
2337 .into();
2338 let follower_id = session.connection_id;
2339
2340 session
2341 .db()
2342 .await
2343 .check_room_participants(room_id, leader_id, session.connection_id)
2344 .await?;
2345
2346 session
2347 .peer
2348 .forward_send(session.connection_id, leader_id, request)?;
2349
2350 if let Some(project_id) = project_id {
2351 let room = session
2352 .db()
2353 .await
2354 .unfollow(room_id, project_id, leader_id, follower_id)
2355 .await?;
2356 room_updated(&room, &session.peer);
2357 }
2358
2359 Ok(())
2360}
2361
2362/// Notify everyone following you of your current location.
2363async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2364 let room_id = RoomId::from_proto(request.room_id);
2365 let database = session.db.lock().await;
2366
2367 let connection_ids = if let Some(project_id) = request.project_id {
2368 let project_id = ProjectId::from_proto(project_id);
2369 database
2370 .project_connection_ids(project_id, session.connection_id)
2371 .await?
2372 } else {
2373 database
2374 .room_connection_ids(room_id, session.connection_id)
2375 .await?
2376 };
2377
2378 // For now, don't send view update messages back to that view's current leader.
2379 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2380 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2381 _ => None,
2382 });
2383
2384 for connection_id in connection_ids.iter().cloned() {
2385 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2386 session
2387 .peer
2388 .forward_send(session.connection_id, connection_id, request.clone())?;
2389 }
2390 }
2391 Ok(())
2392}
2393
2394/// Get public data about users.
2395async fn get_users(
2396 request: proto::GetUsers,
2397 response: Response<proto::GetUsers>,
2398 session: Session,
2399) -> Result<()> {
2400 let user_ids = request
2401 .user_ids
2402 .into_iter()
2403 .map(UserId::from_proto)
2404 .collect();
2405 let users = session
2406 .db()
2407 .await
2408 .get_users_by_ids(user_ids)
2409 .await?
2410 .into_iter()
2411 .map(|user| proto::User {
2412 id: user.id.to_proto(),
2413 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2414 github_login: user.github_login,
2415 })
2416 .collect();
2417 response.send(proto::UsersResponse { users })?;
2418 Ok(())
2419}
2420
2421/// Search for users (to invite) buy Github login
2422async fn fuzzy_search_users(
2423 request: proto::FuzzySearchUsers,
2424 response: Response<proto::FuzzySearchUsers>,
2425 session: UserSession,
2426) -> Result<()> {
2427 let query = request.query;
2428 let users = match query.len() {
2429 0 => vec![],
2430 1 | 2 => session
2431 .db()
2432 .await
2433 .get_user_by_github_login(&query)
2434 .await?
2435 .into_iter()
2436 .collect(),
2437 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2438 };
2439 let users = users
2440 .into_iter()
2441 .filter(|user| user.id != session.user_id())
2442 .map(|user| proto::User {
2443 id: user.id.to_proto(),
2444 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2445 github_login: user.github_login,
2446 })
2447 .collect();
2448 response.send(proto::UsersResponse { users })?;
2449 Ok(())
2450}
2451
2452/// Send a contact request to another user.
2453async fn request_contact(
2454 request: proto::RequestContact,
2455 response: Response<proto::RequestContact>,
2456 session: UserSession,
2457) -> Result<()> {
2458 let requester_id = session.user_id();
2459 let responder_id = UserId::from_proto(request.responder_id);
2460 if requester_id == responder_id {
2461 return Err(anyhow!("cannot add yourself as a contact"))?;
2462 }
2463
2464 let notifications = session
2465 .db()
2466 .await
2467 .send_contact_request(requester_id, responder_id)
2468 .await?;
2469
2470 // Update outgoing contact requests of requester
2471 let mut update = proto::UpdateContacts::default();
2472 update.outgoing_requests.push(responder_id.to_proto());
2473 for connection_id in session
2474 .connection_pool()
2475 .await
2476 .user_connection_ids(requester_id)
2477 {
2478 session.peer.send(connection_id, update.clone())?;
2479 }
2480
2481 // Update incoming contact requests of responder
2482 let mut update = proto::UpdateContacts::default();
2483 update
2484 .incoming_requests
2485 .push(proto::IncomingContactRequest {
2486 requester_id: requester_id.to_proto(),
2487 });
2488 let connection_pool = session.connection_pool().await;
2489 for connection_id in connection_pool.user_connection_ids(responder_id) {
2490 session.peer.send(connection_id, update.clone())?;
2491 }
2492
2493 send_notifications(&connection_pool, &session.peer, notifications);
2494
2495 response.send(proto::Ack {})?;
2496 Ok(())
2497}
2498
2499/// Accept or decline a contact request
2500async fn respond_to_contact_request(
2501 request: proto::RespondToContactRequest,
2502 response: Response<proto::RespondToContactRequest>,
2503 session: UserSession,
2504) -> Result<()> {
2505 let responder_id = session.user_id();
2506 let requester_id = UserId::from_proto(request.requester_id);
2507 let db = session.db().await;
2508 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2509 db.dismiss_contact_notification(responder_id, requester_id)
2510 .await?;
2511 } else {
2512 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2513
2514 let notifications = db
2515 .respond_to_contact_request(responder_id, requester_id, accept)
2516 .await?;
2517 let requester_busy = db.is_user_busy(requester_id).await?;
2518 let responder_busy = db.is_user_busy(responder_id).await?;
2519
2520 let pool = session.connection_pool().await;
2521 // Update responder with new contact
2522 let mut update = proto::UpdateContacts::default();
2523 if accept {
2524 update
2525 .contacts
2526 .push(contact_for_user(requester_id, requester_busy, &pool));
2527 }
2528 update
2529 .remove_incoming_requests
2530 .push(requester_id.to_proto());
2531 for connection_id in pool.user_connection_ids(responder_id) {
2532 session.peer.send(connection_id, update.clone())?;
2533 }
2534
2535 // Update requester with new contact
2536 let mut update = proto::UpdateContacts::default();
2537 if accept {
2538 update
2539 .contacts
2540 .push(contact_for_user(responder_id, responder_busy, &pool));
2541 }
2542 update
2543 .remove_outgoing_requests
2544 .push(responder_id.to_proto());
2545
2546 for connection_id in pool.user_connection_ids(requester_id) {
2547 session.peer.send(connection_id, update.clone())?;
2548 }
2549
2550 send_notifications(&pool, &session.peer, notifications);
2551 }
2552
2553 response.send(proto::Ack {})?;
2554 Ok(())
2555}
2556
2557/// Remove a contact.
2558async fn remove_contact(
2559 request: proto::RemoveContact,
2560 response: Response<proto::RemoveContact>,
2561 session: UserSession,
2562) -> Result<()> {
2563 let requester_id = session.user_id();
2564 let responder_id = UserId::from_proto(request.user_id);
2565 let db = session.db().await;
2566 let (contact_accepted, deleted_notification_id) =
2567 db.remove_contact(requester_id, responder_id).await?;
2568
2569 let pool = session.connection_pool().await;
2570 // Update outgoing contact requests of requester
2571 let mut update = proto::UpdateContacts::default();
2572 if contact_accepted {
2573 update.remove_contacts.push(responder_id.to_proto());
2574 } else {
2575 update
2576 .remove_outgoing_requests
2577 .push(responder_id.to_proto());
2578 }
2579 for connection_id in pool.user_connection_ids(requester_id) {
2580 session.peer.send(connection_id, update.clone())?;
2581 }
2582
2583 // Update incoming contact requests of responder
2584 let mut update = proto::UpdateContacts::default();
2585 if contact_accepted {
2586 update.remove_contacts.push(requester_id.to_proto());
2587 } else {
2588 update
2589 .remove_incoming_requests
2590 .push(requester_id.to_proto());
2591 }
2592 for connection_id in pool.user_connection_ids(responder_id) {
2593 session.peer.send(connection_id, update.clone())?;
2594 if let Some(notification_id) = deleted_notification_id {
2595 session.peer.send(
2596 connection_id,
2597 proto::DeleteNotification {
2598 notification_id: notification_id.to_proto(),
2599 },
2600 )?;
2601 }
2602 }
2603
2604 response.send(proto::Ack {})?;
2605 Ok(())
2606}
2607
2608/// Creates a new channel.
2609async fn create_channel(
2610 request: proto::CreateChannel,
2611 response: Response<proto::CreateChannel>,
2612 session: UserSession,
2613) -> Result<()> {
2614 let db = session.db().await;
2615
2616 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2617 let (channel, membership) = db
2618 .create_channel(&request.name, parent_id, session.user_id())
2619 .await?;
2620
2621 let root_id = channel.root_id();
2622 let channel = Channel::from_model(channel);
2623
2624 response.send(proto::CreateChannelResponse {
2625 channel: Some(channel.to_proto()),
2626 parent_id: request.parent_id,
2627 })?;
2628
2629 let mut connection_pool = session.connection_pool().await;
2630 if let Some(membership) = membership {
2631 connection_pool.subscribe_to_channel(
2632 membership.user_id,
2633 membership.channel_id,
2634 membership.role,
2635 );
2636 let update = proto::UpdateUserChannels {
2637 channel_memberships: vec![proto::ChannelMembership {
2638 channel_id: membership.channel_id.to_proto(),
2639 role: membership.role.into(),
2640 }],
2641 ..Default::default()
2642 };
2643 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2644 session.peer.send(connection_id, update.clone())?;
2645 }
2646 }
2647
2648 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2649 if !role.can_see_channel(channel.visibility) {
2650 continue;
2651 }
2652
2653 let update = proto::UpdateChannels {
2654 channels: vec![channel.to_proto()],
2655 ..Default::default()
2656 };
2657 session.peer.send(connection_id, update.clone())?;
2658 }
2659
2660 Ok(())
2661}
2662
2663/// Delete a channel
2664async fn delete_channel(
2665 request: proto::DeleteChannel,
2666 response: Response<proto::DeleteChannel>,
2667 session: UserSession,
2668) -> Result<()> {
2669 let db = session.db().await;
2670
2671 let channel_id = request.channel_id;
2672 let (root_channel, removed_channels) = db
2673 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2674 .await?;
2675 response.send(proto::Ack {})?;
2676
2677 // Notify members of removed channels
2678 let mut update = proto::UpdateChannels::default();
2679 update
2680 .delete_channels
2681 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2682
2683 let connection_pool = session.connection_pool().await;
2684 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2685 session.peer.send(connection_id, update.clone())?;
2686 }
2687
2688 Ok(())
2689}
2690
2691/// Invite someone to join a channel.
2692async fn invite_channel_member(
2693 request: proto::InviteChannelMember,
2694 response: Response<proto::InviteChannelMember>,
2695 session: UserSession,
2696) -> Result<()> {
2697 let db = session.db().await;
2698 let channel_id = ChannelId::from_proto(request.channel_id);
2699 let invitee_id = UserId::from_proto(request.user_id);
2700 let InviteMemberResult {
2701 channel,
2702 notifications,
2703 } = db
2704 .invite_channel_member(
2705 channel_id,
2706 invitee_id,
2707 session.user_id(),
2708 request.role().into(),
2709 )
2710 .await?;
2711
2712 let update = proto::UpdateChannels {
2713 channel_invitations: vec![channel.to_proto()],
2714 ..Default::default()
2715 };
2716
2717 let connection_pool = session.connection_pool().await;
2718 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2719 session.peer.send(connection_id, update.clone())?;
2720 }
2721
2722 send_notifications(&connection_pool, &session.peer, notifications);
2723
2724 response.send(proto::Ack {})?;
2725 Ok(())
2726}
2727
2728/// remove someone from a channel
2729async fn remove_channel_member(
2730 request: proto::RemoveChannelMember,
2731 response: Response<proto::RemoveChannelMember>,
2732 session: UserSession,
2733) -> Result<()> {
2734 let db = session.db().await;
2735 let channel_id = ChannelId::from_proto(request.channel_id);
2736 let member_id = UserId::from_proto(request.user_id);
2737
2738 let RemoveChannelMemberResult {
2739 membership_update,
2740 notification_id,
2741 } = db
2742 .remove_channel_member(channel_id, member_id, session.user_id())
2743 .await?;
2744
2745 let mut connection_pool = session.connection_pool().await;
2746 notify_membership_updated(
2747 &mut connection_pool,
2748 membership_update,
2749 member_id,
2750 &session.peer,
2751 );
2752 for connection_id in connection_pool.user_connection_ids(member_id) {
2753 if let Some(notification_id) = notification_id {
2754 session
2755 .peer
2756 .send(
2757 connection_id,
2758 proto::DeleteNotification {
2759 notification_id: notification_id.to_proto(),
2760 },
2761 )
2762 .trace_err();
2763 }
2764 }
2765
2766 response.send(proto::Ack {})?;
2767 Ok(())
2768}
2769
2770/// Toggle the channel between public and private.
2771/// Care is taken to maintain the invariant that public channels only descend from public channels,
2772/// (though members-only channels can appear at any point in the hierarchy).
2773async fn set_channel_visibility(
2774 request: proto::SetChannelVisibility,
2775 response: Response<proto::SetChannelVisibility>,
2776 session: UserSession,
2777) -> Result<()> {
2778 let db = session.db().await;
2779 let channel_id = ChannelId::from_proto(request.channel_id);
2780 let visibility = request.visibility().into();
2781
2782 let channel_model = db
2783 .set_channel_visibility(channel_id, visibility, session.user_id())
2784 .await?;
2785 let root_id = channel_model.root_id();
2786 let channel = Channel::from_model(channel_model);
2787
2788 let mut connection_pool = session.connection_pool().await;
2789 for (user_id, role) in connection_pool
2790 .channel_user_ids(root_id)
2791 .collect::<Vec<_>>()
2792 .into_iter()
2793 {
2794 let update = if role.can_see_channel(channel.visibility) {
2795 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2796 proto::UpdateChannels {
2797 channels: vec![channel.to_proto()],
2798 ..Default::default()
2799 }
2800 } else {
2801 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2802 proto::UpdateChannels {
2803 delete_channels: vec![channel.id.to_proto()],
2804 ..Default::default()
2805 }
2806 };
2807
2808 for connection_id in connection_pool.user_connection_ids(user_id) {
2809 session.peer.send(connection_id, update.clone())?;
2810 }
2811 }
2812
2813 response.send(proto::Ack {})?;
2814 Ok(())
2815}
2816
2817/// Alter the role for a user in the channel.
2818async fn set_channel_member_role(
2819 request: proto::SetChannelMemberRole,
2820 response: Response<proto::SetChannelMemberRole>,
2821 session: UserSession,
2822) -> Result<()> {
2823 let db = session.db().await;
2824 let channel_id = ChannelId::from_proto(request.channel_id);
2825 let member_id = UserId::from_proto(request.user_id);
2826 let result = db
2827 .set_channel_member_role(
2828 channel_id,
2829 session.user_id(),
2830 member_id,
2831 request.role().into(),
2832 )
2833 .await?;
2834
2835 match result {
2836 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2837 let mut connection_pool = session.connection_pool().await;
2838 notify_membership_updated(
2839 &mut connection_pool,
2840 membership_update,
2841 member_id,
2842 &session.peer,
2843 )
2844 }
2845 db::SetMemberRoleResult::InviteUpdated(channel) => {
2846 let update = proto::UpdateChannels {
2847 channel_invitations: vec![channel.to_proto()],
2848 ..Default::default()
2849 };
2850
2851 for connection_id in session
2852 .connection_pool()
2853 .await
2854 .user_connection_ids(member_id)
2855 {
2856 session.peer.send(connection_id, update.clone())?;
2857 }
2858 }
2859 }
2860
2861 response.send(proto::Ack {})?;
2862 Ok(())
2863}
2864
2865/// Change the name of a channel
2866async fn rename_channel(
2867 request: proto::RenameChannel,
2868 response: Response<proto::RenameChannel>,
2869 session: UserSession,
2870) -> Result<()> {
2871 let db = session.db().await;
2872 let channel_id = ChannelId::from_proto(request.channel_id);
2873 let channel_model = db
2874 .rename_channel(channel_id, session.user_id(), &request.name)
2875 .await?;
2876 let root_id = channel_model.root_id();
2877 let channel = Channel::from_model(channel_model);
2878
2879 response.send(proto::RenameChannelResponse {
2880 channel: Some(channel.to_proto()),
2881 })?;
2882
2883 let connection_pool = session.connection_pool().await;
2884 let update = proto::UpdateChannels {
2885 channels: vec![channel.to_proto()],
2886 ..Default::default()
2887 };
2888 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2889 if role.can_see_channel(channel.visibility) {
2890 session.peer.send(connection_id, update.clone())?;
2891 }
2892 }
2893
2894 Ok(())
2895}
2896
2897/// Move a channel to a new parent.
2898async fn move_channel(
2899 request: proto::MoveChannel,
2900 response: Response<proto::MoveChannel>,
2901 session: UserSession,
2902) -> Result<()> {
2903 let channel_id = ChannelId::from_proto(request.channel_id);
2904 let to = ChannelId::from_proto(request.to);
2905
2906 let (root_id, channels) = session
2907 .db()
2908 .await
2909 .move_channel(channel_id, to, session.user_id())
2910 .await?;
2911
2912 let connection_pool = session.connection_pool().await;
2913 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2914 let channels = channels
2915 .iter()
2916 .filter_map(|channel| {
2917 if role.can_see_channel(channel.visibility) {
2918 Some(channel.to_proto())
2919 } else {
2920 None
2921 }
2922 })
2923 .collect::<Vec<_>>();
2924 if channels.is_empty() {
2925 continue;
2926 }
2927
2928 let update = proto::UpdateChannels {
2929 channels,
2930 ..Default::default()
2931 };
2932
2933 session.peer.send(connection_id, update.clone())?;
2934 }
2935
2936 response.send(Ack {})?;
2937 Ok(())
2938}
2939
2940/// Get the list of channel members
2941async fn get_channel_members(
2942 request: proto::GetChannelMembers,
2943 response: Response<proto::GetChannelMembers>,
2944 session: UserSession,
2945) -> Result<()> {
2946 let db = session.db().await;
2947 let channel_id = ChannelId::from_proto(request.channel_id);
2948 let members = db
2949 .get_channel_participant_details(channel_id, session.user_id())
2950 .await?;
2951 response.send(proto::GetChannelMembersResponse { members })?;
2952 Ok(())
2953}
2954
2955/// Accept or decline a channel invitation.
2956async fn respond_to_channel_invite(
2957 request: proto::RespondToChannelInvite,
2958 response: Response<proto::RespondToChannelInvite>,
2959 session: UserSession,
2960) -> Result<()> {
2961 let db = session.db().await;
2962 let channel_id = ChannelId::from_proto(request.channel_id);
2963 let RespondToChannelInvite {
2964 membership_update,
2965 notifications,
2966 } = db
2967 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2968 .await?;
2969
2970 let mut connection_pool = session.connection_pool().await;
2971 if let Some(membership_update) = membership_update {
2972 notify_membership_updated(
2973 &mut connection_pool,
2974 membership_update,
2975 session.user_id(),
2976 &session.peer,
2977 );
2978 } else {
2979 let update = proto::UpdateChannels {
2980 remove_channel_invitations: vec![channel_id.to_proto()],
2981 ..Default::default()
2982 };
2983
2984 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2985 session.peer.send(connection_id, update.clone())?;
2986 }
2987 };
2988
2989 send_notifications(&connection_pool, &session.peer, notifications);
2990
2991 response.send(proto::Ack {})?;
2992
2993 Ok(())
2994}
2995
2996/// Join the channels' room
2997async fn join_channel(
2998 request: proto::JoinChannel,
2999 response: Response<proto::JoinChannel>,
3000 session: UserSession,
3001) -> Result<()> {
3002 let channel_id = ChannelId::from_proto(request.channel_id);
3003 join_channel_internal(channel_id, Box::new(response), session).await
3004}
3005
3006trait JoinChannelInternalResponse {
3007 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3008}
3009impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3010 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3011 Response::<proto::JoinChannel>::send(self, result)
3012 }
3013}
3014impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3015 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3016 Response::<proto::JoinRoom>::send(self, result)
3017 }
3018}
3019
/// Shared implementation for joining a channel's room. `response` may be any
/// handle implementing `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. `db` is dropped before `leave_room_for_session` and
///    re-acquired afterwards, presumably because that call takes the same
///    lock (TODO confirm).
/// 2. Join via `db.join_channel`, which yields the joined room, an optional
///    membership change, and the user's channel role.
/// 3. If a LiveKit client is configured, mint credentials. Guests get a guest
///    token with `can_publish: false`; every other role gets a room token
///    with `can_publish: true`. If token creation fails, `trace_err` turns it
///    into `None` and the connection info is omitted.
/// 4. Reply to the caller. Then, while holding the connection pool, propagate
///    any membership change and broadcast the updated room.
/// 5. After the guards from steps 1-4 go out of scope, broadcast the channel
///    update with a freshly acquired pool and refresh the user's contacts.
///
/// # Errors
/// Returns an error if a database call fails, if leaving the stale room
/// fails, if sending the response fails, if the database returned no channel
/// for the joined room ("channel not returned"), or if updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // `db` and `connection_pool` below are scoped to this block, so both are
    // released before `channel_updated` / `update_user_contacts` run.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may listen but not publish; all other roles may publish.
        // A token error yields `None` via `trace_err()?` inside this closure.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above have been dropped; a fresh pool is acquired here.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3110
3111/// Start editing the channel notes
3112async fn join_channel_buffer(
3113 request: proto::JoinChannelBuffer,
3114 response: Response<proto::JoinChannelBuffer>,
3115 session: UserSession,
3116) -> Result<()> {
3117 let db = session.db().await;
3118 let channel_id = ChannelId::from_proto(request.channel_id);
3119
3120 let open_response = db
3121 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3122 .await?;
3123
3124 let collaborators = open_response.collaborators.clone();
3125 response.send(open_response)?;
3126
3127 let update = UpdateChannelBufferCollaborators {
3128 channel_id: channel_id.to_proto(),
3129 collaborators: collaborators.clone(),
3130 };
3131 channel_buffer_updated(
3132 session.connection_id,
3133 collaborators
3134 .iter()
3135 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3136 &update,
3137 &session.peer,
3138 );
3139
3140 Ok(())
3141}
3142
3143/// Edit the channel notes
3144async fn update_channel_buffer(
3145 request: proto::UpdateChannelBuffer,
3146 session: UserSession,
3147) -> Result<()> {
3148 let db = session.db().await;
3149 let channel_id = ChannelId::from_proto(request.channel_id);
3150
3151 let (collaborators, non_collaborators, epoch, version) = db
3152 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3153 .await?;
3154
3155 channel_buffer_updated(
3156 session.connection_id,
3157 collaborators,
3158 &proto::UpdateChannelBuffer {
3159 channel_id: channel_id.to_proto(),
3160 operations: request.operations,
3161 },
3162 &session.peer,
3163 );
3164
3165 let pool = &*session.connection_pool().await;
3166
3167 broadcast(
3168 None,
3169 non_collaborators
3170 .iter()
3171 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3172 |peer_id| {
3173 session.peer.send(
3174 peer_id,
3175 proto::UpdateChannels {
3176 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3177 channel_id: channel_id.to_proto(),
3178 epoch: epoch as u64,
3179 version: version.clone(),
3180 }],
3181 ..Default::default()
3182 },
3183 )
3184 },
3185 );
3186
3187 Ok(())
3188}
3189
3190/// Rejoin the channel notes after a connection blip
3191async fn rejoin_channel_buffers(
3192 request: proto::RejoinChannelBuffers,
3193 response: Response<proto::RejoinChannelBuffers>,
3194 session: UserSession,
3195) -> Result<()> {
3196 let db = session.db().await;
3197 let buffers = db
3198 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3199 .await?;
3200
3201 for rejoined_buffer in &buffers {
3202 let collaborators_to_notify = rejoined_buffer
3203 .buffer
3204 .collaborators
3205 .iter()
3206 .filter_map(|c| Some(c.peer_id?.into()));
3207 channel_buffer_updated(
3208 session.connection_id,
3209 collaborators_to_notify,
3210 &proto::UpdateChannelBufferCollaborators {
3211 channel_id: rejoined_buffer.buffer.channel_id,
3212 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3213 },
3214 &session.peer,
3215 );
3216 }
3217
3218 response.send(proto::RejoinChannelBuffersResponse {
3219 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3220 })?;
3221
3222 Ok(())
3223}
3224
3225/// Stop editing the channel notes
3226async fn leave_channel_buffer(
3227 request: proto::LeaveChannelBuffer,
3228 response: Response<proto::LeaveChannelBuffer>,
3229 session: UserSession,
3230) -> Result<()> {
3231 let db = session.db().await;
3232 let channel_id = ChannelId::from_proto(request.channel_id);
3233
3234 let left_buffer = db
3235 .leave_channel_buffer(channel_id, session.connection_id)
3236 .await?;
3237
3238 response.send(Ack {})?;
3239
3240 channel_buffer_updated(
3241 session.connection_id,
3242 left_buffer.connections,
3243 &proto::UpdateChannelBufferCollaborators {
3244 channel_id: channel_id.to_proto(),
3245 collaborators: left_buffer.collaborators,
3246 },
3247 &session.peer,
3248 );
3249
3250 Ok(())
3251}
3252
3253fn channel_buffer_updated<T: EnvelopedMessage>(
3254 sender_id: ConnectionId,
3255 collaborators: impl IntoIterator<Item = ConnectionId>,
3256 message: &T,
3257 peer: &Peer,
3258) {
3259 broadcast(Some(sender_id), collaborators, |peer_id| {
3260 peer.send(peer_id, message.clone())
3261 });
3262}
3263
3264fn send_notifications(
3265 connection_pool: &ConnectionPool,
3266 peer: &Peer,
3267 notifications: db::NotificationBatch,
3268) {
3269 for (user_id, notification) in notifications {
3270 for connection_id in connection_pool.user_connection_ids(user_id) {
3271 if let Err(error) = peer.send(
3272 connection_id,
3273 proto::AddNotification {
3274 notification: Some(notification.clone()),
3275 },
3276 ) {
3277 tracing::error!(
3278 "failed to send notification to {:?} {}",
3279 connection_id,
3280 error
3281 );
3282 }
3283 }
3284 }
3285}
3286
3287/// Send a message to the channel
3288async fn send_channel_message(
3289 request: proto::SendChannelMessage,
3290 response: Response<proto::SendChannelMessage>,
3291 session: UserSession,
3292) -> Result<()> {
3293 // Validate the message body.
3294 let body = request.body.trim().to_string();
3295 if body.len() > MAX_MESSAGE_LEN {
3296 return Err(anyhow!("message is too long"))?;
3297 }
3298 if body.is_empty() {
3299 return Err(anyhow!("message can't be blank"))?;
3300 }
3301
3302 // TODO: adjust mentions if body is trimmed
3303
3304 let timestamp = OffsetDateTime::now_utc();
3305 let nonce = request
3306 .nonce
3307 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3308
3309 let channel_id = ChannelId::from_proto(request.channel_id);
3310 let CreatedChannelMessage {
3311 message_id,
3312 participant_connection_ids,
3313 channel_members,
3314 notifications,
3315 } = session
3316 .db()
3317 .await
3318 .create_channel_message(
3319 channel_id,
3320 session.user_id(),
3321 &body,
3322 &request.mentions,
3323 timestamp,
3324 nonce.clone().into(),
3325 match request.reply_to_message_id {
3326 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3327 None => None,
3328 },
3329 )
3330 .await?;
3331
3332 let message = proto::ChannelMessage {
3333 sender_id: session.user_id().to_proto(),
3334 id: message_id.to_proto(),
3335 body,
3336 mentions: request.mentions,
3337 timestamp: timestamp.unix_timestamp() as u64,
3338 nonce: Some(nonce),
3339 reply_to_message_id: request.reply_to_message_id,
3340 edited_at: None,
3341 };
3342 broadcast(
3343 Some(session.connection_id),
3344 participant_connection_ids,
3345 |connection| {
3346 session.peer.send(
3347 connection,
3348 proto::ChannelMessageSent {
3349 channel_id: channel_id.to_proto(),
3350 message: Some(message.clone()),
3351 },
3352 )
3353 },
3354 );
3355 response.send(proto::SendChannelMessageResponse {
3356 message: Some(message),
3357 })?;
3358
3359 let pool = &*session.connection_pool().await;
3360 broadcast(
3361 None,
3362 channel_members
3363 .iter()
3364 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3365 |peer_id| {
3366 session.peer.send(
3367 peer_id,
3368 proto::UpdateChannels {
3369 latest_channel_message_ids: vec![proto::ChannelMessageId {
3370 channel_id: channel_id.to_proto(),
3371 message_id: message_id.to_proto(),
3372 }],
3373 ..Default::default()
3374 },
3375 )
3376 },
3377 );
3378 send_notifications(pool, &session.peer, notifications);
3379
3380 Ok(())
3381}
3382
3383/// Delete a channel message
3384async fn remove_channel_message(
3385 request: proto::RemoveChannelMessage,
3386 response: Response<proto::RemoveChannelMessage>,
3387 session: UserSession,
3388) -> Result<()> {
3389 let channel_id = ChannelId::from_proto(request.channel_id);
3390 let message_id = MessageId::from_proto(request.message_id);
3391 let (connection_ids, existing_notification_ids) = session
3392 .db()
3393 .await
3394 .remove_channel_message(channel_id, message_id, session.user_id())
3395 .await?;
3396
3397 broadcast(
3398 Some(session.connection_id),
3399 connection_ids,
3400 move |connection| {
3401 session.peer.send(connection, request.clone())?;
3402
3403 for notification_id in &existing_notification_ids {
3404 session.peer.send(
3405 connection,
3406 proto::DeleteNotification {
3407 notification_id: (*notification_id).to_proto(),
3408 },
3409 )?;
3410 }
3411
3412 Ok(())
3413 },
3414 );
3415 response.send(proto::Ack {})?;
3416 Ok(())
3417}
3418
3419async fn update_channel_message(
3420 request: proto::UpdateChannelMessage,
3421 response: Response<proto::UpdateChannelMessage>,
3422 session: UserSession,
3423) -> Result<()> {
3424 let channel_id = ChannelId::from_proto(request.channel_id);
3425 let message_id = MessageId::from_proto(request.message_id);
3426 let updated_at = OffsetDateTime::now_utc();
3427 let UpdatedChannelMessage {
3428 message_id,
3429 participant_connection_ids,
3430 notifications,
3431 reply_to_message_id,
3432 timestamp,
3433 deleted_mention_notification_ids,
3434 updated_mention_notifications,
3435 } = session
3436 .db()
3437 .await
3438 .update_channel_message(
3439 channel_id,
3440 message_id,
3441 session.user_id(),
3442 request.body.as_str(),
3443 &request.mentions,
3444 updated_at,
3445 )
3446 .await?;
3447
3448 let nonce = request
3449 .nonce
3450 .clone()
3451 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3452
3453 let message = proto::ChannelMessage {
3454 sender_id: session.user_id().to_proto(),
3455 id: message_id.to_proto(),
3456 body: request.body.clone(),
3457 mentions: request.mentions.clone(),
3458 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3459 nonce: Some(nonce),
3460 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3461 edited_at: Some(updated_at.unix_timestamp() as u64),
3462 };
3463
3464 response.send(proto::Ack {})?;
3465
3466 let pool = &*session.connection_pool().await;
3467 broadcast(
3468 Some(session.connection_id),
3469 participant_connection_ids,
3470 |connection| {
3471 session.peer.send(
3472 connection,
3473 proto::ChannelMessageUpdate {
3474 channel_id: channel_id.to_proto(),
3475 message: Some(message.clone()),
3476 },
3477 )?;
3478
3479 for notification_id in &deleted_mention_notification_ids {
3480 session.peer.send(
3481 connection,
3482 proto::DeleteNotification {
3483 notification_id: (*notification_id).to_proto(),
3484 },
3485 )?;
3486 }
3487
3488 for notification in &updated_mention_notifications {
3489 session.peer.send(
3490 connection,
3491 proto::UpdateNotification {
3492 notification: Some(notification.clone()),
3493 },
3494 )?;
3495 }
3496
3497 Ok(())
3498 },
3499 );
3500
3501 send_notifications(pool, &session.peer, notifications);
3502
3503 Ok(())
3504}
3505
3506/// Mark a channel message as read
3507async fn acknowledge_channel_message(
3508 request: proto::AckChannelMessage,
3509 session: UserSession,
3510) -> Result<()> {
3511 let channel_id = ChannelId::from_proto(request.channel_id);
3512 let message_id = MessageId::from_proto(request.message_id);
3513 let notifications = session
3514 .db()
3515 .await
3516 .observe_channel_message(channel_id, session.user_id(), message_id)
3517 .await?;
3518 send_notifications(
3519 &*session.connection_pool().await,
3520 &session.peer,
3521 notifications,
3522 );
3523 Ok(())
3524}
3525
3526/// Mark a buffer version as synced
3527async fn acknowledge_buffer_version(
3528 request: proto::AckBufferOperation,
3529 session: UserSession,
3530) -> Result<()> {
3531 let buffer_id = BufferId::from_proto(request.buffer_id);
3532 session
3533 .db()
3534 .await
3535 .observe_buffer_version(
3536 buffer_id,
3537 session.user_id(),
3538 request.epoch as i32,
3539 &request.version,
3540 )
3541 .await?;
3542 Ok(())
3543}
3544
3545struct CompleteWithLanguageModelRateLimit;
3546
3547impl RateLimit for CompleteWithLanguageModelRateLimit {
3548 fn capacity() -> usize {
3549 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3550 .ok()
3551 .and_then(|v| v.parse().ok())
3552 .unwrap_or(120) // Picked arbitrarily
3553 }
3554
3555 fn refill_duration() -> chrono::Duration {
3556 chrono::Duration::hours(1)
3557 }
3558
3559 fn db_name() -> &'static str {
3560 "complete-with-language-model"
3561 }
3562}
3563
3564async fn complete_with_language_model(
3565 request: proto::CompleteWithLanguageModel,
3566 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3567 session: Session,
3568 open_ai_api_key: Option<Arc<str>>,
3569 google_ai_api_key: Option<Arc<str>>,
3570 anthropic_api_key: Option<Arc<str>>,
3571) -> Result<()> {
3572 let Some(session) = session.for_user() else {
3573 return Err(anyhow!("user not found"))?;
3574 };
3575 authorize_access_to_language_models(&session).await?;
3576 session
3577 .rate_limiter
3578 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3579 .await?;
3580
3581 if request.model.starts_with("gpt") {
3582 let api_key =
3583 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3584 complete_with_open_ai(request, response, session, api_key).await?;
3585 } else if request.model.starts_with("gemini") {
3586 let api_key = google_ai_api_key
3587 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3588 complete_with_google_ai(request, response, session, api_key).await?;
3589 } else if request.model.starts_with("claude") {
3590 let api_key = anthropic_api_key
3591 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3592 complete_with_anthropic(request, response, session, api_key).await?;
3593 }
3594
3595 Ok(())
3596}
3597
3598async fn complete_with_open_ai(
3599 request: proto::CompleteWithLanguageModel,
3600 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3601 session: UserSession,
3602 api_key: Arc<str>,
3603) -> Result<()> {
3604 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3605
3606 let mut completion_stream = open_ai::stream_completion(
3607 &session.http_client,
3608 OPEN_AI_API_URL,
3609 &api_key,
3610 crate::ai::language_model_request_to_open_ai(request)?,
3611 )
3612 .await
3613 .context("open_ai::stream_completion request failed")?;
3614
3615 while let Some(event) = completion_stream.next().await {
3616 let event = event?;
3617 response.send(proto::LanguageModelResponse {
3618 choices: event
3619 .choices
3620 .into_iter()
3621 .map(|choice| proto::LanguageModelChoiceDelta {
3622 index: choice.index,
3623 delta: Some(proto::LanguageModelResponseMessage {
3624 role: choice.delta.role.map(|role| match role {
3625 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3626 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3627 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3628 } as i32),
3629 content: choice.delta.content,
3630 }),
3631 finish_reason: choice.finish_reason,
3632 })
3633 .collect(),
3634 })?;
3635 }
3636
3637 Ok(())
3638}
3639
3640async fn complete_with_google_ai(
3641 request: proto::CompleteWithLanguageModel,
3642 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3643 session: UserSession,
3644 api_key: Arc<str>,
3645) -> Result<()> {
3646 let mut stream = google_ai::stream_generate_content(
3647 &session.http_client,
3648 google_ai::API_URL,
3649 api_key.as_ref(),
3650 crate::ai::language_model_request_to_google_ai(request)?,
3651 )
3652 .await
3653 .context("google_ai::stream_generate_content request failed")?;
3654
3655 while let Some(event) = stream.next().await {
3656 let event = event?;
3657 response.send(proto::LanguageModelResponse {
3658 choices: event
3659 .candidates
3660 .unwrap_or_default()
3661 .into_iter()
3662 .map(|candidate| proto::LanguageModelChoiceDelta {
3663 index: candidate.index as u32,
3664 delta: Some(proto::LanguageModelResponseMessage {
3665 role: Some(match candidate.content.role {
3666 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3667 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3668 } as i32),
3669 content: Some(
3670 candidate
3671 .content
3672 .parts
3673 .into_iter()
3674 .filter_map(|part| match part {
3675 google_ai::Part::TextPart(part) => Some(part.text),
3676 google_ai::Part::InlineDataPart(_) => None,
3677 })
3678 .collect(),
3679 ),
3680 }),
3681 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3682 })
3683 .collect(),
3684 })?;
3685 }
3686
3687 Ok(())
3688}
3689
3690async fn complete_with_anthropic(
3691 request: proto::CompleteWithLanguageModel,
3692 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3693 session: UserSession,
3694 api_key: Arc<str>,
3695) -> Result<()> {
3696 let model = anthropic::Model::from_id(&request.model)?;
3697
3698 let mut system_message = String::new();
3699 let messages = request
3700 .messages
3701 .into_iter()
3702 .filter_map(|message| match message.role() {
3703 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
3704 role: anthropic::Role::User,
3705 content: message.content,
3706 }),
3707 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
3708 role: anthropic::Role::Assistant,
3709 content: message.content,
3710 }),
3711 // Anthropic's API breaks system instructions out as a separate field rather
3712 // than having a system message role.
3713 LanguageModelRole::LanguageModelSystem => {
3714 if !system_message.is_empty() {
3715 system_message.push_str("\n\n");
3716 }
3717 system_message.push_str(&message.content);
3718
3719 None
3720 }
3721 })
3722 .collect();
3723
3724 let mut stream = anthropic::stream_completion(
3725 &session.http_client,
3726 "https://api.anthropic.com",
3727 &api_key,
3728 anthropic::Request {
3729 model,
3730 messages,
3731 stream: true,
3732 system: system_message,
3733 max_tokens: 4092,
3734 },
3735 )
3736 .await?;
3737
3738 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
3739
3740 while let Some(event) = stream.next().await {
3741 let event = event?;
3742
3743 match event {
3744 anthropic::ResponseEvent::MessageStart { message } => {
3745 if let Some(role) = message.role {
3746 if role == "assistant" {
3747 current_role = proto::LanguageModelRole::LanguageModelAssistant;
3748 } else if role == "user" {
3749 current_role = proto::LanguageModelRole::LanguageModelUser;
3750 }
3751 }
3752 }
3753 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
3754 match content_block {
3755 anthropic::ContentBlock::Text { text } => {
3756 if !text.is_empty() {
3757 response.send(proto::LanguageModelResponse {
3758 choices: vec![proto::LanguageModelChoiceDelta {
3759 index: 0,
3760 delta: Some(proto::LanguageModelResponseMessage {
3761 role: Some(current_role as i32),
3762 content: Some(text),
3763 }),
3764 finish_reason: None,
3765 }],
3766 })?;
3767 }
3768 }
3769 }
3770 }
3771 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
3772 anthropic::TextDelta::TextDelta { text } => {
3773 response.send(proto::LanguageModelResponse {
3774 choices: vec![proto::LanguageModelChoiceDelta {
3775 index: 0,
3776 delta: Some(proto::LanguageModelResponseMessage {
3777 role: Some(current_role as i32),
3778 content: Some(text),
3779 }),
3780 finish_reason: None,
3781 }],
3782 })?;
3783 }
3784 },
3785 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
3786 if let Some(stop_reason) = delta.stop_reason {
3787 response.send(proto::LanguageModelResponse {
3788 choices: vec![proto::LanguageModelChoiceDelta {
3789 index: 0,
3790 delta: None,
3791 finish_reason: Some(stop_reason),
3792 }],
3793 })?;
3794 }
3795 }
3796 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
3797 anthropic::ResponseEvent::MessageStop {} => {}
3798 anthropic::ResponseEvent::Ping {} => {}
3799 }
3800 }
3801
3802 Ok(())
3803}
3804
3805struct CountTokensWithLanguageModelRateLimit;
3806
3807impl RateLimit for CountTokensWithLanguageModelRateLimit {
3808 fn capacity() -> usize {
3809 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3810 .ok()
3811 .and_then(|v| v.parse().ok())
3812 .unwrap_or(600) // Picked arbitrarily
3813 }
3814
3815 fn refill_duration() -> chrono::Duration {
3816 chrono::Duration::hours(1)
3817 }
3818
3819 fn db_name() -> &'static str {
3820 "count-tokens-with-language-model"
3821 }
3822}
3823
3824async fn count_tokens_with_language_model(
3825 request: proto::CountTokensWithLanguageModel,
3826 response: Response<proto::CountTokensWithLanguageModel>,
3827 session: UserSession,
3828 google_ai_api_key: Option<Arc<str>>,
3829) -> Result<()> {
3830 authorize_access_to_language_models(&session).await?;
3831
3832 if !request.model.starts_with("gemini") {
3833 return Err(anyhow!(
3834 "counting tokens for model: {:?} is not supported",
3835 request.model
3836 ))?;
3837 }
3838
3839 session
3840 .rate_limiter
3841 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3842 .await?;
3843
3844 let api_key = google_ai_api_key
3845 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3846 let tokens_response = google_ai::count_tokens(
3847 &session.http_client,
3848 google_ai::API_URL,
3849 &api_key,
3850 crate::ai::count_tokens_request_to_google_ai(request)?,
3851 )
3852 .await?;
3853 response.send(proto::CountTokensResponse {
3854 token_count: tokens_response.total_tokens as u32,
3855 })?;
3856 Ok(())
3857}
3858
3859async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3860 let db = session.db().await;
3861 let flags = db.get_user_flags(session.user_id()).await?;
3862 if flags.iter().any(|flag| flag == "language-models") {
3863 Ok(())
3864 } else {
3865 Err(anyhow!("permission denied"))?
3866 }
3867}
3868
3869/// Start receiving chat updates for a channel
3870async fn join_channel_chat(
3871 request: proto::JoinChannelChat,
3872 response: Response<proto::JoinChannelChat>,
3873 session: UserSession,
3874) -> Result<()> {
3875 let channel_id = ChannelId::from_proto(request.channel_id);
3876
3877 let db = session.db().await;
3878 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3879 .await?;
3880 let messages = db
3881 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3882 .await?;
3883 response.send(proto::JoinChannelChatResponse {
3884 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3885 messages,
3886 })?;
3887 Ok(())
3888}
3889
3890/// Stop receiving chat updates for a channel
3891async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3892 let channel_id = ChannelId::from_proto(request.channel_id);
3893 session
3894 .db()
3895 .await
3896 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3897 .await?;
3898 Ok(())
3899}
3900
3901/// Retrieve the chat history for a channel
3902async fn get_channel_messages(
3903 request: proto::GetChannelMessages,
3904 response: Response<proto::GetChannelMessages>,
3905 session: UserSession,
3906) -> Result<()> {
3907 let channel_id = ChannelId::from_proto(request.channel_id);
3908 let messages = session
3909 .db()
3910 .await
3911 .get_channel_messages(
3912 channel_id,
3913 session.user_id(),
3914 MESSAGE_COUNT_PER_PAGE,
3915 Some(MessageId::from_proto(request.before_message_id)),
3916 )
3917 .await?;
3918 response.send(proto::GetChannelMessagesResponse {
3919 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3920 messages,
3921 })?;
3922 Ok(())
3923}
3924
3925/// Retrieve specific chat messages
3926async fn get_channel_messages_by_id(
3927 request: proto::GetChannelMessagesById,
3928 response: Response<proto::GetChannelMessagesById>,
3929 session: UserSession,
3930) -> Result<()> {
3931 let message_ids = request
3932 .message_ids
3933 .iter()
3934 .map(|id| MessageId::from_proto(*id))
3935 .collect::<Vec<_>>();
3936 let messages = session
3937 .db()
3938 .await
3939 .get_channel_messages_by_id(session.user_id(), &message_ids)
3940 .await?;
3941 response.send(proto::GetChannelMessagesResponse {
3942 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3943 messages,
3944 })?;
3945 Ok(())
3946}
3947
3948/// Retrieve the current users notifications
3949async fn get_notifications(
3950 request: proto::GetNotifications,
3951 response: Response<proto::GetNotifications>,
3952 session: UserSession,
3953) -> Result<()> {
3954 let notifications = session
3955 .db()
3956 .await
3957 .get_notifications(
3958 session.user_id(),
3959 NOTIFICATION_COUNT_PER_PAGE,
3960 request
3961 .before_id
3962 .map(|id| db::NotificationId::from_proto(id)),
3963 )
3964 .await?;
3965 response.send(proto::GetNotificationsResponse {
3966 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3967 notifications,
3968 })?;
3969 Ok(())
3970}
3971
3972/// Mark notifications as read
3973async fn mark_notification_as_read(
3974 request: proto::MarkNotificationRead,
3975 response: Response<proto::MarkNotificationRead>,
3976 session: UserSession,
3977) -> Result<()> {
3978 let database = &session.db().await;
3979 let notifications = database
3980 .mark_notification_as_read_by_id(
3981 session.user_id(),
3982 NotificationId::from_proto(request.notification_id),
3983 )
3984 .await?;
3985 send_notifications(
3986 &*session.connection_pool().await,
3987 &session.peer,
3988 notifications,
3989 );
3990 response.send(proto::Ack {})?;
3991 Ok(())
3992}
3993
3994/// Get the current users information
3995async fn get_private_user_info(
3996 _request: proto::GetPrivateUserInfo,
3997 response: Response<proto::GetPrivateUserInfo>,
3998 session: UserSession,
3999) -> Result<()> {
4000 let db = session.db().await;
4001
4002 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4003 let user = db
4004 .get_user_by_id(session.user_id())
4005 .await?
4006 .ok_or_else(|| anyhow!("user not found"))?;
4007 let flags = db.get_user_flags(session.user_id()).await?;
4008
4009 response.send(proto::GetPrivateUserInfoResponse {
4010 metrics_id,
4011 staff: user.admin,
4012 flags,
4013 })?;
4014 Ok(())
4015}
4016
4017fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4018 match message {
4019 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4020 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4021 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4022 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4023 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4024 code: frame.code.into(),
4025 reason: frame.reason,
4026 })),
4027 }
4028}
4029
4030fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4031 match message {
4032 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4033 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4034 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4035 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4036 AxumMessage::Close(frame) => {
4037 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4038 code: frame.code.into(),
4039 reason: frame.reason,
4040 }))
4041 }
4042 }
4043}
4044
4045fn notify_membership_updated(
4046 connection_pool: &mut ConnectionPool,
4047 result: MembershipUpdated,
4048 user_id: UserId,
4049 peer: &Peer,
4050) {
4051 for membership in &result.new_channels.channel_memberships {
4052 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4053 }
4054 for channel_id in &result.removed_channels {
4055 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4056 }
4057
4058 let user_channels_update = proto::UpdateUserChannels {
4059 channel_memberships: result
4060 .new_channels
4061 .channel_memberships
4062 .iter()
4063 .map(|cm| proto::ChannelMembership {
4064 channel_id: cm.channel_id.to_proto(),
4065 role: cm.role.into(),
4066 })
4067 .collect(),
4068 ..Default::default()
4069 };
4070
4071 let mut update = build_channels_update(result.new_channels, vec![]);
4072 update.delete_channels = result
4073 .removed_channels
4074 .into_iter()
4075 .map(|id| id.to_proto())
4076 .collect();
4077 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4078
4079 for connection_id in connection_pool.user_connection_ids(user_id) {
4080 peer.send(connection_id, user_channels_update.clone())
4081 .trace_err();
4082 peer.send(connection_id, update.clone()).trace_err();
4083 }
4084}
4085
4086fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4087 proto::UpdateUserChannels {
4088 channel_memberships: channels
4089 .channel_memberships
4090 .iter()
4091 .map(|m| proto::ChannelMembership {
4092 channel_id: m.channel_id.to_proto(),
4093 role: m.role.into(),
4094 })
4095 .collect(),
4096 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4097 observed_channel_message_id: channels.observed_channel_messages.clone(),
4098 }
4099}
4100
4101fn build_channels_update(
4102 channels: ChannelsForUser,
4103 channel_invites: Vec<db::Channel>,
4104) -> proto::UpdateChannels {
4105 let mut update = proto::UpdateChannels::default();
4106
4107 for channel in channels.channels {
4108 update.channels.push(channel.to_proto());
4109 }
4110
4111 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4112 update.latest_channel_message_ids = channels.latest_channel_messages;
4113
4114 for (channel_id, participants) in channels.channel_participants {
4115 update
4116 .channel_participants
4117 .push(proto::ChannelParticipants {
4118 channel_id: channel_id.to_proto(),
4119 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4120 });
4121 }
4122
4123 for channel in channel_invites {
4124 update.channel_invitations.push(channel.to_proto());
4125 }
4126 for project in channels.hosted_projects {
4127 update.hosted_projects.push(project);
4128 }
4129
4130 update
4131}
4132
4133fn build_initial_contacts_update(
4134 contacts: Vec<db::Contact>,
4135 pool: &ConnectionPool,
4136) -> proto::UpdateContacts {
4137 let mut update = proto::UpdateContacts::default();
4138
4139 for contact in contacts {
4140 match contact {
4141 db::Contact::Accepted { user_id, busy } => {
4142 update.contacts.push(contact_for_user(user_id, busy, &pool));
4143 }
4144 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4145 db::Contact::Incoming { user_id } => {
4146 update
4147 .incoming_requests
4148 .push(proto::IncomingContactRequest {
4149 requester_id: user_id.to_proto(),
4150 })
4151 }
4152 }
4153 }
4154
4155 update
4156}
4157
4158fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4159 proto::Contact {
4160 user_id: user_id.to_proto(),
4161 online: pool.is_user_online(user_id),
4162 busy,
4163 }
4164}
4165
4166fn room_updated(room: &proto::Room, peer: &Peer) {
4167 broadcast(
4168 None,
4169 room.participants
4170 .iter()
4171 .filter_map(|participant| Some(participant.peer_id?.into())),
4172 |peer_id| {
4173 peer.send(
4174 peer_id,
4175 proto::RoomUpdated {
4176 room: Some(room.clone()),
4177 },
4178 )
4179 },
4180 );
4181}
4182
4183fn channel_updated(
4184 channel: &db::channel::Model,
4185 room: &proto::Room,
4186 peer: &Peer,
4187 pool: &ConnectionPool,
4188) {
4189 let participants = room
4190 .participants
4191 .iter()
4192 .map(|p| p.user_id)
4193 .collect::<Vec<_>>();
4194
4195 broadcast(
4196 None,
4197 pool.channel_connection_ids(channel.root_id())
4198 .filter_map(|(channel_id, role)| {
4199 role.can_see_channel(channel.visibility).then(|| channel_id)
4200 }),
4201 |peer_id| {
4202 peer.send(
4203 peer_id,
4204 proto::UpdateChannels {
4205 channel_participants: vec![proto::ChannelParticipants {
4206 channel_id: channel.id.to_proto(),
4207 participant_user_ids: participants.clone(),
4208 }],
4209 ..Default::default()
4210 },
4211 )
4212 },
4213 );
4214}
4215
4216async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4217 let db = session.db().await;
4218
4219 let contacts = db.get_contacts(user_id).await?;
4220 let busy = db.is_user_busy(user_id).await?;
4221
4222 let pool = session.connection_pool().await;
4223 let updated_contact = contact_for_user(user_id, busy, &pool);
4224 for contact in contacts {
4225 if let db::Contact::Accepted {
4226 user_id: contact_user_id,
4227 ..
4228 } = contact
4229 {
4230 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4231 session
4232 .peer
4233 .send(
4234 contact_conn_id,
4235 proto::UpdateContacts {
4236 contacts: vec![updated_contact.clone()],
4237 remove_contacts: Default::default(),
4238 incoming_requests: Default::default(),
4239 remove_incoming_requests: Default::default(),
4240 outgoing_requests: Default::default(),
4241 remove_outgoing_requests: Default::default(),
4242 },
4243 )
4244 .trace_err();
4245 }
4246 }
4247 }
4248 Ok(())
4249}
4250
/// Remove `connection_id` from the room it is in and clean up after it.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything. Otherwise, in order, it:
/// - notifies the other connections of each project that was left
///   (`project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room has a channel, broadcasts that channel's participant list;
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - refreshes the contact state of this session's user and of each of those
///   users;
/// - if a LiveKit client is configured, removes the user from the LiveKit
///   room, and deletes that room when the database marked it `deleted`.
///   LiveKit failures are logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact state is re-broadcast at the end.
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. They are declared out here so they
    // outlive `left_room`, while the `else` branch can still return early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // This is taken out of `left_room.room` before the room itself is
        // taken, so the room broadcast below carries the field's default value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool handle is released before the awaits
    // below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4324
4325async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4326 let left_channel_buffers = session
4327 .db()
4328 .await
4329 .leave_channel_buffers(session.connection_id)
4330 .await?;
4331
4332 for left_buffer in left_channel_buffers {
4333 channel_buffer_updated(
4334 session.connection_id,
4335 left_buffer.connections,
4336 &proto::UpdateChannelBufferCollaborators {
4337 channel_id: left_buffer.channel_id.to_proto(),
4338 collaborators: left_buffer.collaborators,
4339 },
4340 &session.peer,
4341 );
4342 }
4343
4344 Ok(())
4345}
4346
4347fn project_left(project: &db::LeftProject, session: &UserSession) {
4348 for connection_id in &project.connection_ids {
4349 if project.host_user_id == Some(session.user_id()) {
4350 session
4351 .peer
4352 .send(
4353 *connection_id,
4354 proto::UnshareProject {
4355 project_id: project.id.to_proto(),
4356 },
4357 )
4358 .trace_err();
4359 } else {
4360 session
4361 .peer
4362 .send(
4363 *connection_id,
4364 proto::RemoveProjectCollaborator {
4365 project_id: project.id.to_proto(),
4366 peer_id: Some(session.connection_id.into()),
4367 },
4368 )
4369 .trace_err();
4370 }
4371 }
4372}
4373
4374pub trait ResultExt {
4375 type Ok;
4376
4377 fn trace_err(self) -> Option<Self::Ok>;
4378}
4379
4380impl<T, E> ResultExt for Result<T, E>
4381where
4382 E: std::fmt::Debug,
4383{
4384 type Ok = T;
4385
4386 fn trace_err(self) -> Option<T> {
4387 match self {
4388 Ok(value) => Some(value),
4389 Err(error) => {
4390 tracing::error!("{:?}", error);
4391 None
4392 }
4393 }
4394 }
4395}