1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use semantic_version::SemanticVersion;
50use serde::{Serialize, Serializer};
51use std::{
52 any::TypeId,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc, OnceLock,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{
69 field::{self},
70 info_span, instrument, Instrument,
71};
72use util::http::IsahcHttpClient;
73
74pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
75
76// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
77pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
78
79const MESSAGE_COUNT_PER_PAGE: usize = 100;
80const MAX_MESSAGE_LEN: usize = 1024;
81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
82
83type MessageHandler =
84 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
86struct Response<R> {
87 peer: Arc<Peer>,
88 receipt: Receipt<R>,
89 responded: Arc<AtomicBool>,
90}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
100struct StreamingResponse<R: RequestMessage> {
101 peer: Arc<Peer>,
102 receipt: Receipt<R>,
103}
104
105impl<R: RequestMessage> StreamingResponse<R> {
106 fn send(&self, payload: R::Response) -> Result<()> {
107 self.peer.respond(self.receipt, payload)?;
108 Ok(())
109 }
110}
111
112#[derive(Clone, Debug)]
113pub enum Principal {
114 User(User),
115 Impersonated { user: User, admin: User },
116 DevServer(dev_server::Model),
117}
118
119impl Principal {
120 fn update_span(&self, span: &tracing::Span) {
121 match &self {
122 Principal::User(user) => {
123 span.record("user_id", &user.id.0);
124 span.record("login", &user.github_login);
125 }
126 Principal::Impersonated { user, admin } => {
127 span.record("user_id", &user.id.0);
128 span.record("login", &user.github_login);
129 span.record("impersonator", &admin.github_login);
130 }
131 Principal::DevServer(dev_server) => {
132 span.record("dev_server_id", &dev_server.id.0);
133 }
134 }
135 }
136}
137
138#[derive(Clone)]
139struct Session {
140 principal: Principal,
141 connection_id: ConnectionId,
142 db: Arc<tokio::sync::Mutex<DbHandle>>,
143 peer: Arc<Peer>,
144 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
146 http_client: IsahcHttpClient,
147 rate_limiter: Arc<RateLimiter>,
148 _executor: Executor,
149}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn user_id(&self) -> Option<UserId> {
176 match &self.principal {
177 Principal::User(user) => Some(user.id),
178 Principal::Impersonated { user, .. } => Some(user.id),
179 Principal::DevServer(_) => None,
180 }
181 }
182}
183
184impl Debug for Session {
185 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
186 let mut result = f.debug_struct("Session");
187 match &self.principal {
188 Principal::User(user) => {
189 result.field("user", &user.github_login);
190 }
191 Principal::Impersonated { user, admin } => {
192 result.field("user", &user.github_login);
193 result.field("impersonator", &admin.github_login);
194 }
195 Principal::DevServer(dev_server) => {
196 result.field("dev_server", &dev_server.id);
197 }
198 }
199 result.field("connection_id", &self.connection_id).finish()
200 }
201}
202
203struct UserSession(Session);
204
205impl UserSession {
206 pub fn new(s: Session) -> Option<Self> {
207 s.user_id().map(|_| UserSession(s))
208 }
209 pub fn user_id(&self) -> UserId {
210 self.0.user_id().unwrap()
211 }
212}
213
214impl Deref for UserSession {
215 type Target = Session;
216
217 fn deref(&self) -> &Self::Target {
218 &self.0
219 }
220}
221impl DerefMut for UserSession {
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227fn user_handler<M: RequestMessage, Fut>(
228 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
230where
231 Fut: Send + Future<Output = Result<()>>,
232{
233 let handler = Arc::new(handler);
234 move |message, response, session| {
235 let handler = handler.clone();
236 Box::pin(async move {
237 if let Some(user_session) = session.for_user() {
238 Ok(handler(message, response, user_session).await?)
239 } else {
240 Err(Error::Internal(anyhow!("must be a user")))
241 }
242 })
243 }
244}
245
246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
247 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
249where
250 InnertRetFut: Send + Future<Output = Result<()>>,
251{
252 let handler = Arc::new(handler);
253 move |message, session| {
254 let handler = handler.clone();
255 Box::pin(async move {
256 if let Some(user_session) = session.for_user() {
257 Ok(handler(message, user_session).await?)
258 } else {
259 Err(Error::Internal(anyhow!("must be a user")))
260 }
261 })
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
275pub struct Server {
276 id: parking_lot::Mutex<ServerId>,
277 peer: Arc<Peer>,
278 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
279 app_state: Arc<AppState>,
280 handlers: HashMap<TypeId, MessageHandler>,
281 teardown: watch::Sender<bool>,
282}
283
284pub(crate) struct ConnectionPoolGuard<'a> {
285 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
286 _not_send: PhantomData<Rc<()>>,
287}
288
289#[derive(Serialize)]
290pub struct ServerSnapshot<'a> {
291 peer: &'a Peer,
292 #[serde(serialize_with = "serialize_deref")]
293 connection_pool: ConnectionPoolGuard<'a>,
294}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
306 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
307 let mut server = Self {
308 id: parking_lot::Mutex::new(id),
309 peer: Peer::new(id.0 as u32),
310 app_state: app_state.clone(),
311 connection_pool: Default::default(),
312 handlers: Default::default(),
313 teardown: watch::channel(false).0,
314 };
315
316 server
317 .add_request_handler(ping)
318 .add_request_handler(user_handler(create_room))
319 .add_request_handler(user_handler(join_room))
320 .add_request_handler(user_handler(rejoin_room))
321 .add_request_handler(user_handler(leave_room))
322 .add_request_handler(user_handler(set_room_participant_role))
323 .add_request_handler(user_handler(call))
324 .add_request_handler(user_handler(cancel_call))
325 .add_message_handler(user_message_handler(decline_call))
326 .add_request_handler(user_handler(update_participant_location))
327 .add_request_handler(share_project)
328 .add_message_handler(unshare_project)
329 .add_request_handler(user_handler(join_project))
330 .add_request_handler(user_handler(join_hosted_project))
331 .add_message_handler(user_message_handler(leave_project))
332 .add_request_handler(update_project)
333 .add_request_handler(update_worktree)
334 .add_message_handler(start_language_server)
335 .add_message_handler(update_language_server)
336 .add_message_handler(update_diagnostic_summary)
337 .add_message_handler(update_worktree_settings)
338 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
339 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
340 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
341 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
342 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
344 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
345 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
346 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
347 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
348 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
349 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
350 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
353 )
354 .add_request_handler(
355 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
356 )
357 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
358 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
359 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
360 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
361 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
362 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
365 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
369 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
370 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
371 .add_message_handler(create_buffer_for_peer)
372 .add_request_handler(update_buffer)
373 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
375 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
376 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
377 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
378 .add_request_handler(get_users)
379 .add_request_handler(user_handler(fuzzy_search_users))
380 .add_request_handler(user_handler(request_contact))
381 .add_request_handler(user_handler(remove_contact))
382 .add_request_handler(user_handler(respond_to_contact_request))
383 .add_request_handler(user_handler(create_channel))
384 .add_request_handler(user_handler(delete_channel))
385 .add_request_handler(user_handler(invite_channel_member))
386 .add_request_handler(user_handler(remove_channel_member))
387 .add_request_handler(user_handler(set_channel_member_role))
388 .add_request_handler(user_handler(set_channel_visibility))
389 .add_request_handler(user_handler(rename_channel))
390 .add_request_handler(user_handler(join_channel_buffer))
391 .add_request_handler(user_handler(leave_channel_buffer))
392 .add_message_handler(user_message_handler(update_channel_buffer))
393 .add_request_handler(user_handler(rejoin_channel_buffers))
394 .add_request_handler(user_handler(get_channel_members))
395 .add_request_handler(user_handler(respond_to_channel_invite))
396 .add_request_handler(user_handler(join_channel))
397 .add_request_handler(user_handler(join_channel_chat))
398 .add_message_handler(user_message_handler(leave_channel_chat))
399 .add_request_handler(user_handler(send_channel_message))
400 .add_request_handler(user_handler(remove_channel_message))
401 .add_request_handler(user_handler(update_channel_message))
402 .add_request_handler(user_handler(get_channel_messages))
403 .add_request_handler(user_handler(get_channel_messages_by_id))
404 .add_request_handler(user_handler(get_notifications))
405 .add_request_handler(user_handler(mark_notification_as_read))
406 .add_request_handler(user_handler(move_channel))
407 .add_request_handler(user_handler(follow))
408 .add_message_handler(user_message_handler(unfollow))
409 .add_message_handler(user_message_handler(update_followers))
410 .add_request_handler(user_handler(get_private_user_info))
411 .add_message_handler(user_message_handler(acknowledge_channel_message))
412 .add_message_handler(user_message_handler(acknowledge_buffer_version))
413 .add_streaming_request_handler({
414 let app_state = app_state.clone();
415 move |request, response, session| {
416 complete_with_language_model(
417 request,
418 response,
419 session,
420 app_state.config.openai_api_key.clone(),
421 app_state.config.google_ai_api_key.clone(),
422 app_state.config.anthropic_api_key.clone(),
423 )
424 }
425 })
426 .add_request_handler({
427 let app_state = app_state.clone();
428 user_handler(move |request, response, session| {
429 count_tokens_with_language_model(
430 request,
431 response,
432 session,
433 app_state.config.google_ai_api_key.clone(),
434 )
435 })
436 });
437
438 Arc::new(server)
439 }
440
441 pub async fn start(&self) -> Result<()> {
442 let server_id = *self.id.lock();
443 let app_state = self.app_state.clone();
444 let peer = self.peer.clone();
445 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
446 let pool = self.connection_pool.clone();
447 let live_kit_client = self.app_state.live_kit_client.clone();
448
449 let span = info_span!("start server");
450 self.app_state.executor.spawn_detached(
451 async move {
452 tracing::info!("waiting for cleanup timeout");
453 timeout.await;
454 tracing::info!("cleanup timeout expired, retrieving stale rooms");
455 if let Some((room_ids, channel_ids)) = app_state
456 .db
457 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
458 .await
459 .trace_err()
460 {
461 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
462 tracing::info!(
463 stale_channel_buffer_count = channel_ids.len(),
464 "retrieved stale channel buffers"
465 );
466
467 for channel_id in channel_ids {
468 if let Some(refreshed_channel_buffer) = app_state
469 .db
470 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
471 .await
472 .trace_err()
473 {
474 for connection_id in refreshed_channel_buffer.connection_ids {
475 peer.send(
476 connection_id,
477 proto::UpdateChannelBufferCollaborators {
478 channel_id: channel_id.to_proto(),
479 collaborators: refreshed_channel_buffer
480 .collaborators
481 .clone(),
482 },
483 )
484 .trace_err();
485 }
486 }
487 }
488
489 for room_id in room_ids {
490 let mut contacts_to_update = HashSet::default();
491 let mut canceled_calls_to_user_ids = Vec::new();
492 let mut live_kit_room = String::new();
493 let mut delete_live_kit_room = false;
494
495 if let Some(mut refreshed_room) = app_state
496 .db
497 .clear_stale_room_participants(room_id, server_id)
498 .await
499 .trace_err()
500 {
501 tracing::info!(
502 room_id = room_id.0,
503 new_participant_count = refreshed_room.room.participants.len(),
504 "refreshed room"
505 );
506 room_updated(&refreshed_room.room, &peer);
507 if let Some(channel) = refreshed_room.channel.as_ref() {
508 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
509 }
510 contacts_to_update
511 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
512 contacts_to_update
513 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
514 canceled_calls_to_user_ids =
515 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
516 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
517 delete_live_kit_room = refreshed_room.room.participants.is_empty();
518 }
519
520 {
521 let pool = pool.lock();
522 for canceled_user_id in canceled_calls_to_user_ids {
523 for connection_id in pool.user_connection_ids(canceled_user_id) {
524 peer.send(
525 connection_id,
526 proto::CallCanceled {
527 room_id: room_id.to_proto(),
528 },
529 )
530 .trace_err();
531 }
532 }
533 }
534
535 for user_id in contacts_to_update {
536 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
537 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
538 if let Some((busy, contacts)) = busy.zip(contacts) {
539 let pool = pool.lock();
540 let updated_contact = contact_for_user(user_id, busy, &pool);
541 for contact in contacts {
542 if let db::Contact::Accepted {
543 user_id: contact_user_id,
544 ..
545 } = contact
546 {
547 for contact_conn_id in
548 pool.user_connection_ids(contact_user_id)
549 {
550 peer.send(
551 contact_conn_id,
552 proto::UpdateContacts {
553 contacts: vec![updated_contact.clone()],
554 remove_contacts: Default::default(),
555 incoming_requests: Default::default(),
556 remove_incoming_requests: Default::default(),
557 outgoing_requests: Default::default(),
558 remove_outgoing_requests: Default::default(),
559 },
560 )
561 .trace_err();
562 }
563 }
564 }
565 }
566 }
567
568 if let Some(live_kit) = live_kit_client.as_ref() {
569 if delete_live_kit_room {
570 live_kit.delete_room(live_kit_room).await.trace_err();
571 }
572 }
573 }
574 }
575
576 app_state
577 .db
578 .delete_stale_servers(&app_state.config.zed_environment, server_id)
579 .await
580 .trace_err();
581 }
582 .instrument(span),
583 );
584 Ok(())
585 }
586
587 pub fn teardown(&self) {
588 self.peer.teardown();
589 self.connection_pool.lock().reset();
590 let _ = self.teardown.send(true);
591 }
592
593 #[cfg(test)]
594 pub fn reset(&self, id: ServerId) {
595 self.teardown();
596 *self.id.lock() = id;
597 self.peer.reset(id.0 as u32);
598 let _ = self.teardown.send(false);
599 }
600
601 #[cfg(test)]
602 pub fn id(&self) -> ServerId {
603 *self.id.lock()
604 }
605
606 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
607 where
608 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
609 Fut: 'static + Send + Future<Output = Result<()>>,
610 M: EnvelopedMessage,
611 {
612 let prev_handler = self.handlers.insert(
613 TypeId::of::<M>(),
614 Box::new(move |envelope, session| {
615 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
616 let received_at = envelope.received_at;
617 tracing::info!(
618 "message received"
619 );
620 let start_time = Instant::now();
621 let future = (handler)(*envelope, session);
622 async move {
623 let result = future.await;
624 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
625 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
626 let queue_duration_ms = total_duration_ms - processing_duration_ms;
627 match result {
628 Err(error) => {
629 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
630 }
631 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
632 }
633 }
634 .boxed()
635 }),
636 );
637 if prev_handler.is_some() {
638 panic!("registered a handler for the same message twice");
639 }
640 self
641 }
642
643 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
644 where
645 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
646 Fut: 'static + Send + Future<Output = Result<()>>,
647 M: EnvelopedMessage,
648 {
649 self.add_handler(move |envelope, session| handler(envelope.payload, session));
650 self
651 }
652
653 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
654 where
655 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
656 Fut: Send + Future<Output = Result<()>>,
657 M: RequestMessage,
658 {
659 let handler = Arc::new(handler);
660 self.add_handler(move |envelope, session| {
661 let receipt = envelope.receipt();
662 let handler = handler.clone();
663 async move {
664 let peer = session.peer.clone();
665 let responded = Arc::new(AtomicBool::default());
666 let response = Response {
667 peer: peer.clone(),
668 responded: responded.clone(),
669 receipt,
670 };
671 match (handler)(envelope.payload, response, session).await {
672 Ok(()) => {
673 if responded.load(std::sync::atomic::Ordering::SeqCst) {
674 Ok(())
675 } else {
676 Err(anyhow!("handler did not send a response"))?
677 }
678 }
679 Err(error) => {
680 let proto_err = match &error {
681 Error::Internal(err) => err.to_proto(),
682 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
683 };
684 peer.respond_with_error(receipt, proto_err)?;
685 Err(error)
686 }
687 }
688 }
689 })
690 }
691
692 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
693 where
694 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
695 Fut: Send + Future<Output = Result<()>>,
696 M: RequestMessage,
697 {
698 let handler = Arc::new(handler);
699 self.add_handler(move |envelope, session| {
700 let receipt = envelope.receipt();
701 let handler = handler.clone();
702 async move {
703 let peer = session.peer.clone();
704 let response = StreamingResponse {
705 peer: peer.clone(),
706 receipt,
707 };
708 match (handler)(envelope.payload, response, session).await {
709 Ok(()) => {
710 peer.end_stream(receipt)?;
711 Ok(())
712 }
713 Err(error) => {
714 let proto_err = match &error {
715 Error::Internal(err) => err.to_proto(),
716 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
717 };
718 peer.respond_with_error(receipt, proto_err)?;
719 Err(error)
720 }
721 }
722 }
723 })
724 }
725
726 #[allow(clippy::too_many_arguments)]
727 pub fn handle_connection(
728 self: &Arc<Self>,
729 connection: Connection,
730 address: String,
731 principal: Principal,
732 zed_version: ZedVersion,
733 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
734 executor: Executor,
735 ) -> impl Future<Output = ()> {
736 let this = self.clone();
737 let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
738 principal.update_span(&span);
739
740 let mut teardown = self.teardown.subscribe();
741 async move {
742 if *teardown.borrow() {
743 tracing::error!("server is tearing down");
744 return
745 }
746 let (connection_id, handle_io, mut incoming_rx) = this
747 .peer
748 .add_connection(connection, {
749 let executor = executor.clone();
750 move |duration| executor.sleep(duration)
751 });
752 tracing::Span::current().record("connection_id", format!("{}", connection_id));
753 tracing::info!("connection opened");
754
755 let http_client = match IsahcHttpClient::new() {
756 Ok(http_client) => http_client,
757 Err(error) => {
758 tracing::error!(?error, "failed to create HTTP client");
759 return;
760 }
761 };
762
763 let session = Session {
764 principal: principal.clone(),
765 connection_id,
766 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
767 peer: this.peer.clone(),
768 connection_pool: this.connection_pool.clone(),
769 live_kit_client: this.app_state.live_kit_client.clone(),
770 http_client,
771 rate_limiter: this.app_state.rate_limiter.clone(),
772 _executor: executor.clone(),
773 };
774
775 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
776 tracing::error!(?error, "failed to send initial client update");
777 return;
778 }
779
780 let handle_io = handle_io.fuse();
781 futures::pin_mut!(handle_io);
782
783 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
784 // This prevents deadlocks when e.g., client A performs a request to client B and
785 // client B performs a request to client A. If both clients stop processing further
786 // messages until their respective request completes, they won't have a chance to
787 // respond to the other client's request and cause a deadlock.
788 //
789 // This arrangement ensures we will attempt to process earlier messages first, but fall
790 // back to processing messages arrived later in the spirit of making progress.
791 let mut foreground_message_handlers = FuturesUnordered::new();
792 let concurrent_handlers = Arc::new(Semaphore::new(256));
793 loop {
794 let next_message = async {
795 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
796 let message = incoming_rx.next().await;
797 (permit, message)
798 }.fuse();
799 futures::pin_mut!(next_message);
800 futures::select_biased! {
801 _ = teardown.changed().fuse() => return,
802 result = handle_io => {
803 if let Err(error) = result {
804 tracing::error!(?error, "error handling I/O");
805 }
806 break;
807 }
808 _ = foreground_message_handlers.next() => {}
809 next_message = next_message => {
810 let (permit, message) = next_message;
811 if let Some(message) = message {
812 let type_name = message.payload_type_name();
813 // note: we copy all the fields from the parent span so we can query them in the logs.
814 // (https://github.com/tokio-rs/tracing/issues/2670).
815 let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
816 principal.update_span(&span);
817 let span_enter = span.enter();
818 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
819 let is_background = message.is_background();
820 let handle_message = (handler)(message, session.clone());
821 drop(span_enter);
822
823 let handle_message = async move {
824 handle_message.await;
825 drop(permit);
826 }.instrument(span);
827 if is_background {
828 executor.spawn_detached(handle_message);
829 } else {
830 foreground_message_handlers.push(handle_message);
831 }
832 } else {
833 tracing::error!("no message handler");
834 }
835 } else {
836 tracing::info!("connection closed");
837 break;
838 }
839 }
840 }
841 }
842
843 drop(foreground_message_handlers);
844 tracing::info!("signing out");
845 if let Err(error) = connection_lost(session, teardown, executor).await {
846 tracing::error!(?error, "error signing out");
847 }
848
849 }.instrument(span)
850 }
851
852 async fn send_initial_client_update(
853 &self,
854 connection_id: ConnectionId,
855 principal: &Principal,
856 zed_version: ZedVersion,
857 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
858 session: &Session,
859 ) -> Result<()> {
860 self.peer.send(
861 connection_id,
862 proto::Hello {
863 peer_id: Some(connection_id.into()),
864 },
865 )?;
866 tracing::info!("sent hello message");
867
868 let Principal::User(user) = principal else {
869 return Ok(());
870 };
871
872 if let Some(send_connection_id) = send_connection_id.take() {
873 let _ = send_connection_id.send(connection_id);
874 }
875
876 if !user.connected_once {
877 self.peer.send(connection_id, proto::ShowContacts {})?;
878 self.app_state
879 .db
880 .set_user_connected_once(user.id, true)
881 .await?;
882 }
883
884 let (contacts, channels_for_user, channel_invites) = future::try_join3(
885 self.app_state.db.get_contacts(user.id),
886 self.app_state.db.get_channels_for_user(user.id),
887 self.app_state.db.get_channel_invites_for_user(user.id),
888 )
889 .await?;
890
891 {
892 let mut pool = self.connection_pool.lock();
893 pool.add_connection(connection_id, user.id, user.admin, zed_version);
894 for membership in &channels_for_user.channel_memberships {
895 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
896 }
897 self.peer.send(
898 connection_id,
899 build_initial_contacts_update(contacts, &pool),
900 )?;
901 self.peer.send(
902 connection_id,
903 build_update_user_channels(&channels_for_user),
904 )?;
905 self.peer.send(
906 connection_id,
907 build_channels_update(channels_for_user, channel_invites),
908 )?;
909 }
910
911 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
912 self.peer.send(connection_id, incoming_call)?;
913 }
914
915 update_user_contacts(user.id, &session).await?;
916 Ok(())
917 }
918
919 pub async fn invite_code_redeemed(
920 self: &Arc<Self>,
921 inviter_id: UserId,
922 invitee_id: UserId,
923 ) -> Result<()> {
924 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
925 if let Some(code) = &user.invite_code {
926 let pool = self.connection_pool.lock();
927 let invitee_contact = contact_for_user(invitee_id, false, &pool);
928 for connection_id in pool.user_connection_ids(inviter_id) {
929 self.peer.send(
930 connection_id,
931 proto::UpdateContacts {
932 contacts: vec![invitee_contact.clone()],
933 ..Default::default()
934 },
935 )?;
936 self.peer.send(
937 connection_id,
938 proto::UpdateInviteInfo {
939 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
940 count: user.invite_count as u32,
941 },
942 )?;
943 }
944 }
945 }
946 Ok(())
947 }
948
949 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
950 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
951 if let Some(invite_code) = &user.invite_code {
952 let pool = self.connection_pool.lock();
953 for connection_id in pool.user_connection_ids(user_id) {
954 self.peer.send(
955 connection_id,
956 proto::UpdateInviteInfo {
957 url: format!(
958 "{}{}",
959 self.app_state.config.invite_link_prefix, invite_code
960 ),
961 count: user.invite_count as u32,
962 },
963 )?;
964 }
965 }
966 }
967 Ok(())
968 }
969
970 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
971 ServerSnapshot {
972 connection_pool: ConnectionPoolGuard {
973 guard: self.connection_pool.lock(),
974 _not_send: PhantomData,
975 },
976 peer: &self.peer,
977 }
978 }
979}
980
981impl<'a> Deref for ConnectionPoolGuard<'a> {
982 type Target = ConnectionPool;
983
984 fn deref(&self) -> &Self::Target {
985 &self.guard
986 }
987}
988
989impl<'a> DerefMut for ConnectionPoolGuard<'a> {
990 fn deref_mut(&mut self) -> &mut Self::Target {
991 &mut self.guard
992 }
993}
994
995impl<'a> Drop for ConnectionPoolGuard<'a> {
996 fn drop(&mut self) {
997 #[cfg(test)]
998 self.check_invariants();
999 }
1000}
1001
1002fn broadcast<F>(
1003 sender_id: Option<ConnectionId>,
1004 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005 mut f: F,
1006) where
1007 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009 for receiver_id in receiver_ids {
1010 if Some(receiver_id) != sender_id {
1011 if let Err(error) = f(receiver_id) {
1012 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013 }
1014 }
1015 }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021 fn name() -> &'static HeaderName {
1022 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024 }
1025
1026 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027 where
1028 Self: Sized,
1029 I: Iterator<Item = &'i axum::http::HeaderValue>,
1030 {
1031 let version = values
1032 .next()
1033 .ok_or_else(axum::headers::Error::invalid)?
1034 .to_str()
1035 .map_err(|_| axum::headers::Error::invalid())?
1036 .parse()
1037 .map_err(|_| axum::headers::Error::invalid())?;
1038 Ok(Self(version))
1039 }
1040
1041 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042 values.extend([self.0.to_string().parse().unwrap()]);
1043 }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048 fn name() -> &'static HeaderName {
1049 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051 }
1052
1053 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054 where
1055 Self: Sized,
1056 I: Iterator<Item = &'i axum::http::HeaderValue>,
1057 {
1058 let version = values
1059 .next()
1060 .ok_or_else(axum::headers::Error::invalid)?
1061 .to_str()
1062 .map_err(|_| axum::headers::Error::invalid())?
1063 .parse()
1064 .map_err(|_| axum::headers::Error::invalid())?;
1065 Ok(Self(version))
1066 }
1067
1068 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069 values.extend([self.0.to_string().parse().unwrap()]);
1070 }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074 Router::new()
1075 .route("/rpc", get(handle_websocket_request))
1076 .layer(
1077 ServiceBuilder::new()
1078 .layer(Extension(server.app_state.clone()))
1079 .layer(middleware::from_fn(auth::validate_header)),
1080 )
1081 .route("/metrics", get(handle_metrics))
1082 .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089 Extension(server): Extension<Arc<Server>>,
1090 Extension(principal): Extension<Principal>,
1091 ws: WebSocketUpgrade,
1092) -> axum::response::Response {
1093 if protocol_version != rpc::PROTOCOL_VERSION {
1094 return (
1095 StatusCode::UPGRADE_REQUIRED,
1096 "client must be upgraded".to_string(),
1097 )
1098 .into_response();
1099 }
1100
1101 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1102 return (
1103 StatusCode::UPGRADE_REQUIRED,
1104 "no version header found".to_string(),
1105 )
1106 .into_response();
1107 };
1108
1109 if !version.can_collaborate() {
1110 return (
1111 StatusCode::UPGRADE_REQUIRED,
1112 "client must be upgraded".to_string(),
1113 )
1114 .into_response();
1115 }
1116
1117 let socket_address = socket_address.to_string();
1118 ws.on_upgrade(move |socket| {
1119 let socket = socket
1120 .map_ok(to_tungstenite_message)
1121 .err_into()
1122 .with(|message| async move { Ok(to_axum_message(message)) });
1123 let connection = Connection::new(Box::pin(socket));
1124 async move {
1125 server
1126 .handle_connection(
1127 connection,
1128 socket_address,
1129 principal,
1130 version,
1131 None,
1132 Executor::Production,
1133 )
1134 .await;
1135 }
1136 })
1137}
1138
1139pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1140 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1141 let connections_metric = CONNECTIONS_METRIC
1142 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1143
1144 let connections = server
1145 .connection_pool
1146 .lock()
1147 .connections()
1148 .filter(|connection| !connection.admin)
1149 .count();
1150 connections_metric.set(connections as _);
1151
1152 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1154 register_int_gauge!(
1155 "shared_projects",
1156 "number of open projects with one or more guests"
1157 )
1158 .unwrap()
1159 });
1160
1161 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1162 shared_projects_metric.set(shared_projects as _);
1163
1164 let encoder = prometheus::TextEncoder::new();
1165 let metric_families = prometheus::gather();
1166 let encoded_metrics = encoder
1167 .encode_to_string(&metric_families)
1168 .map_err(|err| anyhow!("{}", err))?;
1169 Ok(encoded_metrics)
1170}
1171
1172#[instrument(err, skip(executor))]
1173async fn connection_lost(
1174 session: Session,
1175 mut teardown: watch::Receiver<bool>,
1176 executor: Executor,
1177) -> Result<()> {
1178 session.peer.disconnect(session.connection_id);
1179 session
1180 .connection_pool()
1181 .await
1182 .remove_connection(session.connection_id)?;
1183
1184 session
1185 .db()
1186 .await
1187 .connection_lost(session.connection_id)
1188 .await
1189 .trace_err();
1190
1191 futures::select_biased! {
1192 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1193 if let Some(session) = session.for_user() {
1194 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1195 leave_room_for_session(&session).await.trace_err();
1196 leave_channel_buffers_for_session(&session)
1197 .await
1198 .trace_err();
1199
1200 if !session
1201 .connection_pool()
1202 .await
1203 .is_user_online(session.user_id())
1204 {
1205 let db = session.db().await;
1206 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1207 room_updated(&room, &session.peer);
1208 }
1209 }
1210
1211 update_user_contacts(session.user_id(), &session).await?;
1212 }
1213 }
1214 _ = teardown.changed().fuse() => {}
1215 }
1216
1217 Ok(())
1218}
1219
1220/// Acknowledges a ping from a client, used to keep the connection alive.
1221async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1222 response.send(proto::Ack {})?;
1223 Ok(())
1224}
1225
1226/// Creates a new room for calling (outside of channels)
1227async fn create_room(
1228 _request: proto::CreateRoom,
1229 response: Response<proto::CreateRoom>,
1230 session: UserSession,
1231) -> Result<()> {
1232 let live_kit_room = nanoid::nanoid!(30);
1233
1234 let live_kit_connection_info = util::maybe!(async {
1235 let live_kit = session.live_kit_client.as_ref();
1236 let live_kit = live_kit?;
1237 let user_id = session.user_id().to_string();
1238
1239 let token = live_kit
1240 .room_token(&live_kit_room, &user_id.to_string())
1241 .trace_err()?;
1242
1243 Some(proto::LiveKitConnectionInfo {
1244 server_url: live_kit.url().into(),
1245 token,
1246 can_publish: true,
1247 })
1248 })
1249 .await;
1250
1251 let room = session
1252 .db()
1253 .await
1254 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1255 .await?;
1256
1257 response.send(proto::CreateRoomResponse {
1258 room: Some(room.clone()),
1259 live_kit_connection_info,
1260 })?;
1261
1262 update_user_contacts(session.user_id(), &session).await?;
1263 Ok(())
1264}
1265
1266/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1267async fn join_room(
1268 request: proto::JoinRoom,
1269 response: Response<proto::JoinRoom>,
1270 session: UserSession,
1271) -> Result<()> {
1272 let room_id = RoomId::from_proto(request.id);
1273
1274 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1275
1276 if let Some(channel_id) = channel_id {
1277 return join_channel_internal(channel_id, Box::new(response), session).await;
1278 }
1279
1280 let joined_room = {
1281 let room = session
1282 .db()
1283 .await
1284 .join_room(room_id, session.user_id(), session.connection_id)
1285 .await?;
1286 room_updated(&room.room, &session.peer);
1287 room.into_inner()
1288 };
1289
1290 for connection_id in session
1291 .connection_pool()
1292 .await
1293 .user_connection_ids(session.user_id())
1294 {
1295 session
1296 .peer
1297 .send(
1298 connection_id,
1299 proto::CallCanceled {
1300 room_id: room_id.to_proto(),
1301 },
1302 )
1303 .trace_err();
1304 }
1305
1306 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1307 if let Some(token) = live_kit
1308 .room_token(
1309 &joined_room.room.live_kit_room,
1310 &session.user_id().to_string(),
1311 )
1312 .trace_err()
1313 {
1314 Some(proto::LiveKitConnectionInfo {
1315 server_url: live_kit.url().into(),
1316 token,
1317 can_publish: true,
1318 })
1319 } else {
1320 None
1321 }
1322 } else {
1323 None
1324 };
1325
1326 response.send(proto::JoinRoomResponse {
1327 room: Some(joined_room.room),
1328 channel_id: None,
1329 live_kit_connection_info,
1330 })?;
1331
1332 update_user_contacts(session.user_id(), &session).await?;
1333 Ok(())
1334}
1335
1336/// Rejoin room is used to reconnect to a room after connection errors.
1337async fn rejoin_room(
1338 request: proto::RejoinRoom,
1339 response: Response<proto::RejoinRoom>,
1340 session: UserSession,
1341) -> Result<()> {
1342 let room;
1343 let channel;
1344 {
1345 let mut rejoined_room = session
1346 .db()
1347 .await
1348 .rejoin_room(request, session.user_id(), session.connection_id)
1349 .await?;
1350
1351 response.send(proto::RejoinRoomResponse {
1352 room: Some(rejoined_room.room.clone()),
1353 reshared_projects: rejoined_room
1354 .reshared_projects
1355 .iter()
1356 .map(|project| proto::ResharedProject {
1357 id: project.id.to_proto(),
1358 collaborators: project
1359 .collaborators
1360 .iter()
1361 .map(|collaborator| collaborator.to_proto())
1362 .collect(),
1363 })
1364 .collect(),
1365 rejoined_projects: rejoined_room
1366 .rejoined_projects
1367 .iter()
1368 .map(|rejoined_project| proto::RejoinedProject {
1369 id: rejoined_project.id.to_proto(),
1370 worktrees: rejoined_project
1371 .worktrees
1372 .iter()
1373 .map(|worktree| proto::WorktreeMetadata {
1374 id: worktree.id,
1375 root_name: worktree.root_name.clone(),
1376 visible: worktree.visible,
1377 abs_path: worktree.abs_path.clone(),
1378 })
1379 .collect(),
1380 collaborators: rejoined_project
1381 .collaborators
1382 .iter()
1383 .map(|collaborator| collaborator.to_proto())
1384 .collect(),
1385 language_servers: rejoined_project.language_servers.clone(),
1386 })
1387 .collect(),
1388 })?;
1389 room_updated(&rejoined_room.room, &session.peer);
1390
1391 for project in &rejoined_room.reshared_projects {
1392 for collaborator in &project.collaborators {
1393 session
1394 .peer
1395 .send(
1396 collaborator.connection_id,
1397 proto::UpdateProjectCollaborator {
1398 project_id: project.id.to_proto(),
1399 old_peer_id: Some(project.old_connection_id.into()),
1400 new_peer_id: Some(session.connection_id.into()),
1401 },
1402 )
1403 .trace_err();
1404 }
1405
1406 broadcast(
1407 Some(session.connection_id),
1408 project
1409 .collaborators
1410 .iter()
1411 .map(|collaborator| collaborator.connection_id),
1412 |connection_id| {
1413 session.peer.forward_send(
1414 session.connection_id,
1415 connection_id,
1416 proto::UpdateProject {
1417 project_id: project.id.to_proto(),
1418 worktrees: project.worktrees.clone(),
1419 },
1420 )
1421 },
1422 );
1423 }
1424
1425 for project in &rejoined_room.rejoined_projects {
1426 for collaborator in &project.collaborators {
1427 session
1428 .peer
1429 .send(
1430 collaborator.connection_id,
1431 proto::UpdateProjectCollaborator {
1432 project_id: project.id.to_proto(),
1433 old_peer_id: Some(project.old_connection_id.into()),
1434 new_peer_id: Some(session.connection_id.into()),
1435 },
1436 )
1437 .trace_err();
1438 }
1439 }
1440
1441 for project in &mut rejoined_room.rejoined_projects {
1442 for worktree in mem::take(&mut project.worktrees) {
1443 #[cfg(any(test, feature = "test-support"))]
1444 const MAX_CHUNK_SIZE: usize = 2;
1445 #[cfg(not(any(test, feature = "test-support")))]
1446 const MAX_CHUNK_SIZE: usize = 256;
1447
1448 // Stream this worktree's entries.
1449 let message = proto::UpdateWorktree {
1450 project_id: project.id.to_proto(),
1451 worktree_id: worktree.id,
1452 abs_path: worktree.abs_path.clone(),
1453 root_name: worktree.root_name,
1454 updated_entries: worktree.updated_entries,
1455 removed_entries: worktree.removed_entries,
1456 scan_id: worktree.scan_id,
1457 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1458 updated_repositories: worktree.updated_repositories,
1459 removed_repositories: worktree.removed_repositories,
1460 };
1461 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1462 session.peer.send(session.connection_id, update.clone())?;
1463 }
1464
1465 // Stream this worktree's diagnostics.
1466 for summary in worktree.diagnostic_summaries {
1467 session.peer.send(
1468 session.connection_id,
1469 proto::UpdateDiagnosticSummary {
1470 project_id: project.id.to_proto(),
1471 worktree_id: worktree.id,
1472 summary: Some(summary),
1473 },
1474 )?;
1475 }
1476
1477 for settings_file in worktree.settings_files {
1478 session.peer.send(
1479 session.connection_id,
1480 proto::UpdateWorktreeSettings {
1481 project_id: project.id.to_proto(),
1482 worktree_id: worktree.id,
1483 path: settings_file.path,
1484 content: Some(settings_file.content),
1485 },
1486 )?;
1487 }
1488 }
1489
1490 for language_server in &project.language_servers {
1491 session.peer.send(
1492 session.connection_id,
1493 proto::UpdateLanguageServer {
1494 project_id: project.id.to_proto(),
1495 language_server_id: language_server.id,
1496 variant: Some(
1497 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1498 proto::LspDiskBasedDiagnosticsUpdated {},
1499 ),
1500 ),
1501 },
1502 )?;
1503 }
1504 }
1505
1506 let rejoined_room = rejoined_room.into_inner();
1507
1508 room = rejoined_room.room;
1509 channel = rejoined_room.channel;
1510 }
1511
1512 if let Some(channel) = channel {
1513 channel_updated(
1514 &channel,
1515 &room,
1516 &session.peer,
1517 &*session.connection_pool().await,
1518 );
1519 }
1520
1521 update_user_contacts(session.user_id(), &session).await?;
1522 Ok(())
1523}
1524
1525/// leave room disconnects from the room.
1526async fn leave_room(
1527 _: proto::LeaveRoom,
1528 response: Response<proto::LeaveRoom>,
1529 session: UserSession,
1530) -> Result<()> {
1531 leave_room_for_session(&session).await?;
1532 response.send(proto::Ack {})?;
1533 Ok(())
1534}
1535
1536/// Updates the permissions of someone else in the room.
1537async fn set_room_participant_role(
1538 request: proto::SetRoomParticipantRole,
1539 response: Response<proto::SetRoomParticipantRole>,
1540 session: UserSession,
1541) -> Result<()> {
1542 let user_id = UserId::from_proto(request.user_id);
1543 let role = ChannelRole::from(request.role());
1544
1545 let (live_kit_room, can_publish) = {
1546 let room = session
1547 .db()
1548 .await
1549 .set_room_participant_role(
1550 session.user_id(),
1551 RoomId::from_proto(request.room_id),
1552 user_id,
1553 role,
1554 )
1555 .await?;
1556
1557 let live_kit_room = room.live_kit_room.clone();
1558 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1559 room_updated(&room, &session.peer);
1560 (live_kit_room, can_publish)
1561 };
1562
1563 if let Some(live_kit) = session.live_kit_client.as_ref() {
1564 live_kit
1565 .update_participant(
1566 live_kit_room.clone(),
1567 request.user_id.to_string(),
1568 live_kit_server::proto::ParticipantPermission {
1569 can_subscribe: true,
1570 can_publish,
1571 can_publish_data: can_publish,
1572 hidden: false,
1573 recorder: false,
1574 },
1575 )
1576 .await
1577 .trace_err();
1578 }
1579
1580 response.send(proto::Ack {})?;
1581 Ok(())
1582}
1583
1584/// Call someone else into the current room
1585async fn call(
1586 request: proto::Call,
1587 response: Response<proto::Call>,
1588 session: UserSession,
1589) -> Result<()> {
1590 let room_id = RoomId::from_proto(request.room_id);
1591 let calling_user_id = session.user_id();
1592 let calling_connection_id = session.connection_id;
1593 let called_user_id = UserId::from_proto(request.called_user_id);
1594 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1595 if !session
1596 .db()
1597 .await
1598 .has_contact(calling_user_id, called_user_id)
1599 .await?
1600 {
1601 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1602 }
1603
1604 let incoming_call = {
1605 let (room, incoming_call) = &mut *session
1606 .db()
1607 .await
1608 .call(
1609 room_id,
1610 calling_user_id,
1611 calling_connection_id,
1612 called_user_id,
1613 initial_project_id,
1614 )
1615 .await?;
1616 room_updated(&room, &session.peer);
1617 mem::take(incoming_call)
1618 };
1619 update_user_contacts(called_user_id, &session).await?;
1620
1621 let mut calls = session
1622 .connection_pool()
1623 .await
1624 .user_connection_ids(called_user_id)
1625 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1626 .collect::<FuturesUnordered<_>>();
1627
1628 while let Some(call_response) = calls.next().await {
1629 match call_response.as_ref() {
1630 Ok(_) => {
1631 response.send(proto::Ack {})?;
1632 return Ok(());
1633 }
1634 Err(_) => {
1635 call_response.trace_err();
1636 }
1637 }
1638 }
1639
1640 {
1641 let room = session
1642 .db()
1643 .await
1644 .call_failed(room_id, called_user_id)
1645 .await?;
1646 room_updated(&room, &session.peer);
1647 }
1648 update_user_contacts(called_user_id, &session).await?;
1649
1650 Err(anyhow!("failed to ring user"))?
1651}
1652
1653/// Cancel an outgoing call.
1654async fn cancel_call(
1655 request: proto::CancelCall,
1656 response: Response<proto::CancelCall>,
1657 session: UserSession,
1658) -> Result<()> {
1659 let called_user_id = UserId::from_proto(request.called_user_id);
1660 let room_id = RoomId::from_proto(request.room_id);
1661 {
1662 let room = session
1663 .db()
1664 .await
1665 .cancel_call(room_id, session.connection_id, called_user_id)
1666 .await?;
1667 room_updated(&room, &session.peer);
1668 }
1669
1670 for connection_id in session
1671 .connection_pool()
1672 .await
1673 .user_connection_ids(called_user_id)
1674 {
1675 session
1676 .peer
1677 .send(
1678 connection_id,
1679 proto::CallCanceled {
1680 room_id: room_id.to_proto(),
1681 },
1682 )
1683 .trace_err();
1684 }
1685 response.send(proto::Ack {})?;
1686
1687 update_user_contacts(called_user_id, &session).await?;
1688 Ok(())
1689}
1690
1691/// Decline an incoming call.
1692async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1693 let room_id = RoomId::from_proto(message.room_id);
1694 {
1695 let room = session
1696 .db()
1697 .await
1698 .decline_call(Some(room_id), session.user_id())
1699 .await?
1700 .ok_or_else(|| anyhow!("failed to decline call"))?;
1701 room_updated(&room, &session.peer);
1702 }
1703
1704 for connection_id in session
1705 .connection_pool()
1706 .await
1707 .user_connection_ids(session.user_id())
1708 {
1709 session
1710 .peer
1711 .send(
1712 connection_id,
1713 proto::CallCanceled {
1714 room_id: room_id.to_proto(),
1715 },
1716 )
1717 .trace_err();
1718 }
1719 update_user_contacts(session.user_id(), &session).await?;
1720 Ok(())
1721}
1722
1723/// Updates other participants in the room with your current location.
1724async fn update_participant_location(
1725 request: proto::UpdateParticipantLocation,
1726 response: Response<proto::UpdateParticipantLocation>,
1727 session: UserSession,
1728) -> Result<()> {
1729 let room_id = RoomId::from_proto(request.room_id);
1730 let location = request
1731 .location
1732 .ok_or_else(|| anyhow!("invalid location"))?;
1733
1734 let db = session.db().await;
1735 let room = db
1736 .update_room_participant_location(room_id, session.connection_id, location)
1737 .await?;
1738
1739 room_updated(&room, &session.peer);
1740 response.send(proto::Ack {})?;
1741 Ok(())
1742}
1743
1744/// Share a project into the room.
1745async fn share_project(
1746 request: proto::ShareProject,
1747 response: Response<proto::ShareProject>,
1748 session: Session,
1749) -> Result<()> {
1750 let (project_id, room) = &*session
1751 .db()
1752 .await
1753 .share_project(
1754 RoomId::from_proto(request.room_id),
1755 session.connection_id,
1756 &request.worktrees,
1757 )
1758 .await?;
1759 response.send(proto::ShareProjectResponse {
1760 project_id: project_id.to_proto(),
1761 })?;
1762 room_updated(&room, &session.peer);
1763
1764 Ok(())
1765}
1766
1767/// Unshare a project from the room.
1768async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1769 let project_id = ProjectId::from_proto(message.project_id);
1770
1771 let (room, guest_connection_ids) = &*session
1772 .db()
1773 .await
1774 .unshare_project(project_id, session.connection_id)
1775 .await?;
1776
1777 broadcast(
1778 Some(session.connection_id),
1779 guest_connection_ids.iter().copied(),
1780 |conn_id| session.peer.send(conn_id, message.clone()),
1781 );
1782 room_updated(&room, &session.peer);
1783
1784 Ok(())
1785}
1786
1787/// Join someone elses shared project.
1788async fn join_project(
1789 request: proto::JoinProject,
1790 response: Response<proto::JoinProject>,
1791 session: UserSession,
1792) -> Result<()> {
1793 let project_id = ProjectId::from_proto(request.project_id);
1794
1795 tracing::info!(%project_id, "join project");
1796
1797 let (project, replica_id) = &mut *session
1798 .db()
1799 .await
1800 .join_project_in_room(project_id, session.connection_id)
1801 .await?;
1802
1803 join_project_internal(response, session, project, replica_id)
1804}
1805
1806trait JoinProjectInternalResponse {
1807 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1808}
1809impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1810 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1811 Response::<proto::JoinProject>::send(self, result)
1812 }
1813}
1814impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1815 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1816 Response::<proto::JoinHostedProject>::send(self, result)
1817 }
1818}
1819
1820fn join_project_internal(
1821 response: impl JoinProjectInternalResponse,
1822 session: UserSession,
1823 project: &mut Project,
1824 replica_id: &ReplicaId,
1825) -> Result<()> {
1826 let collaborators = project
1827 .collaborators
1828 .iter()
1829 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1830 .map(|collaborator| collaborator.to_proto())
1831 .collect::<Vec<_>>();
1832 let project_id = project.id;
1833 let guest_user_id = session.user_id();
1834
1835 let worktrees = project
1836 .worktrees
1837 .iter()
1838 .map(|(id, worktree)| proto::WorktreeMetadata {
1839 id: *id,
1840 root_name: worktree.root_name.clone(),
1841 visible: worktree.visible,
1842 abs_path: worktree.abs_path.clone(),
1843 })
1844 .collect::<Vec<_>>();
1845
1846 for collaborator in &collaborators {
1847 session
1848 .peer
1849 .send(
1850 collaborator.peer_id.unwrap().into(),
1851 proto::AddProjectCollaborator {
1852 project_id: project_id.to_proto(),
1853 collaborator: Some(proto::Collaborator {
1854 peer_id: Some(session.connection_id.into()),
1855 replica_id: replica_id.0 as u32,
1856 user_id: guest_user_id.to_proto(),
1857 }),
1858 },
1859 )
1860 .trace_err();
1861 }
1862
1863 // First, we send the metadata associated with each worktree.
1864 response.send(proto::JoinProjectResponse {
1865 project_id: project.id.0 as u64,
1866 worktrees: worktrees.clone(),
1867 replica_id: replica_id.0 as u32,
1868 collaborators: collaborators.clone(),
1869 language_servers: project.language_servers.clone(),
1870 role: project.role.into(), // todo
1871 })?;
1872
1873 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1874 #[cfg(any(test, feature = "test-support"))]
1875 const MAX_CHUNK_SIZE: usize = 2;
1876 #[cfg(not(any(test, feature = "test-support")))]
1877 const MAX_CHUNK_SIZE: usize = 256;
1878
1879 // Stream this worktree's entries.
1880 let message = proto::UpdateWorktree {
1881 project_id: project_id.to_proto(),
1882 worktree_id,
1883 abs_path: worktree.abs_path.clone(),
1884 root_name: worktree.root_name,
1885 updated_entries: worktree.entries,
1886 removed_entries: Default::default(),
1887 scan_id: worktree.scan_id,
1888 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1889 updated_repositories: worktree.repository_entries.into_values().collect(),
1890 removed_repositories: Default::default(),
1891 };
1892 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1893 session.peer.send(session.connection_id, update.clone())?;
1894 }
1895
1896 // Stream this worktree's diagnostics.
1897 for summary in worktree.diagnostic_summaries {
1898 session.peer.send(
1899 session.connection_id,
1900 proto::UpdateDiagnosticSummary {
1901 project_id: project_id.to_proto(),
1902 worktree_id: worktree.id,
1903 summary: Some(summary),
1904 },
1905 )?;
1906 }
1907
1908 for settings_file in worktree.settings_files {
1909 session.peer.send(
1910 session.connection_id,
1911 proto::UpdateWorktreeSettings {
1912 project_id: project_id.to_proto(),
1913 worktree_id: worktree.id,
1914 path: settings_file.path,
1915 content: Some(settings_file.content),
1916 },
1917 )?;
1918 }
1919 }
1920
1921 for language_server in &project.language_servers {
1922 session.peer.send(
1923 session.connection_id,
1924 proto::UpdateLanguageServer {
1925 project_id: project_id.to_proto(),
1926 language_server_id: language_server.id,
1927 variant: Some(
1928 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1929 proto::LspDiskBasedDiagnosticsUpdated {},
1930 ),
1931 ),
1932 },
1933 )?;
1934 }
1935
1936 Ok(())
1937}
1938
1939/// Leave someone elses shared project.
1940async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1941 let sender_id = session.connection_id;
1942 let project_id = ProjectId::from_proto(request.project_id);
1943 let db = session.db().await;
1944 if db.is_hosted_project(project_id).await? {
1945 let project = db.leave_hosted_project(project_id, sender_id).await?;
1946 project_left(&project, &session);
1947 return Ok(());
1948 }
1949
1950 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1951 tracing::info!(
1952 %project_id,
1953 host_user_id = ?project.host_user_id,
1954 host_connection_id = ?project.host_connection_id,
1955 "leave project"
1956 );
1957
1958 project_left(&project, &session);
1959 room_updated(&room, &session.peer);
1960
1961 Ok(())
1962}
1963
1964async fn join_hosted_project(
1965 request: proto::JoinHostedProject,
1966 response: Response<proto::JoinHostedProject>,
1967 session: UserSession,
1968) -> Result<()> {
1969 let (mut project, replica_id) = session
1970 .db()
1971 .await
1972 .join_hosted_project(
1973 ProjectId(request.project_id as i32),
1974 session.user_id(),
1975 session.connection_id,
1976 )
1977 .await?;
1978
1979 join_project_internal(response, session, &mut project, &replica_id)
1980}
1981
1982/// Updates other participants with changes to the project
1983async fn update_project(
1984 request: proto::UpdateProject,
1985 response: Response<proto::UpdateProject>,
1986 session: Session,
1987) -> Result<()> {
1988 let project_id = ProjectId::from_proto(request.project_id);
1989 let (room, guest_connection_ids) = &*session
1990 .db()
1991 .await
1992 .update_project(project_id, session.connection_id, &request.worktrees)
1993 .await?;
1994 broadcast(
1995 Some(session.connection_id),
1996 guest_connection_ids.iter().copied(),
1997 |connection_id| {
1998 session
1999 .peer
2000 .forward_send(session.connection_id, connection_id, request.clone())
2001 },
2002 );
2003 room_updated(&room, &session.peer);
2004 response.send(proto::Ack {})?;
2005
2006 Ok(())
2007}
2008
2009/// Updates other participants with changes to the worktree
2010async fn update_worktree(
2011 request: proto::UpdateWorktree,
2012 response: Response<proto::UpdateWorktree>,
2013 session: Session,
2014) -> Result<()> {
2015 let guest_connection_ids = session
2016 .db()
2017 .await
2018 .update_worktree(&request, session.connection_id)
2019 .await?;
2020
2021 broadcast(
2022 Some(session.connection_id),
2023 guest_connection_ids.iter().copied(),
2024 |connection_id| {
2025 session
2026 .peer
2027 .forward_send(session.connection_id, connection_id, request.clone())
2028 },
2029 );
2030 response.send(proto::Ack {})?;
2031 Ok(())
2032}
2033
2034/// Updates other participants with changes to the diagnostics
2035async fn update_diagnostic_summary(
2036 message: proto::UpdateDiagnosticSummary,
2037 session: Session,
2038) -> Result<()> {
2039 let guest_connection_ids = session
2040 .db()
2041 .await
2042 .update_diagnostic_summary(&message, session.connection_id)
2043 .await?;
2044
2045 broadcast(
2046 Some(session.connection_id),
2047 guest_connection_ids.iter().copied(),
2048 |connection_id| {
2049 session
2050 .peer
2051 .forward_send(session.connection_id, connection_id, message.clone())
2052 },
2053 );
2054
2055 Ok(())
2056}
2057
2058/// Updates other participants with changes to the worktree settings
2059async fn update_worktree_settings(
2060 message: proto::UpdateWorktreeSettings,
2061 session: Session,
2062) -> Result<()> {
2063 let guest_connection_ids = session
2064 .db()
2065 .await
2066 .update_worktree_settings(&message, session.connection_id)
2067 .await?;
2068
2069 broadcast(
2070 Some(session.connection_id),
2071 guest_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, message.clone())
2076 },
2077 );
2078
2079 Ok(())
2080}
2081
2082/// Notify other participants that a language server has started.
2083async fn start_language_server(
2084 request: proto::StartLanguageServer,
2085 session: Session,
2086) -> Result<()> {
2087 let guest_connection_ids = session
2088 .db()
2089 .await
2090 .start_language_server(&request, session.connection_id)
2091 .await?;
2092
2093 broadcast(
2094 Some(session.connection_id),
2095 guest_connection_ids.iter().copied(),
2096 |connection_id| {
2097 session
2098 .peer
2099 .forward_send(session.connection_id, connection_id, request.clone())
2100 },
2101 );
2102 Ok(())
2103}
2104
2105/// Notify other participants that a language server has changed.
2106async fn update_language_server(
2107 request: proto::UpdateLanguageServer,
2108 session: Session,
2109) -> Result<()> {
2110 let project_id = ProjectId::from_proto(request.project_id);
2111 let project_connection_ids = session
2112 .db()
2113 .await
2114 .project_connection_ids(project_id, session.connection_id)
2115 .await?;
2116 broadcast(
2117 Some(session.connection_id),
2118 project_connection_ids.iter().copied(),
2119 |connection_id| {
2120 session
2121 .peer
2122 .forward_send(session.connection_id, connection_id, request.clone())
2123 },
2124 );
2125 Ok(())
2126}
2127
2128/// forward a project request to the host. These requests should be read only
2129/// as guests are allowed to send them.
2130async fn forward_read_only_project_request<T>(
2131 request: T,
2132 response: Response<T>,
2133 session: Session,
2134) -> Result<()>
2135where
2136 T: EntityMessage + RequestMessage,
2137{
2138 let project_id = ProjectId::from_proto(request.remote_entity_id());
2139 let host_connection_id = session
2140 .db()
2141 .await
2142 .host_for_read_only_project_request(project_id, session.connection_id)
2143 .await?;
2144 let payload = session
2145 .peer
2146 .forward_request(session.connection_id, host_connection_id, request)
2147 .await?;
2148 response.send(payload)?;
2149 Ok(())
2150}
2151
2152/// forward a project request to the host. These requests are disallowed
2153/// for guests.
2154async fn forward_mutating_project_request<T>(
2155 request: T,
2156 response: Response<T>,
2157 session: Session,
2158) -> Result<()>
2159where
2160 T: EntityMessage + RequestMessage,
2161{
2162 let project_id = ProjectId::from_proto(request.remote_entity_id());
2163 let host_connection_id = session
2164 .db()
2165 .await
2166 .host_for_mutating_project_request(project_id, session.connection_id)
2167 .await?;
2168 let payload = session
2169 .peer
2170 .forward_request(session.connection_id, host_connection_id, request)
2171 .await?;
2172 response.send(payload)?;
2173 Ok(())
2174}
2175
2176/// Notify other participants that a new buffer has been created
2177async fn create_buffer_for_peer(
2178 request: proto::CreateBufferForPeer,
2179 session: Session,
2180) -> Result<()> {
2181 session
2182 .db()
2183 .await
2184 .check_user_is_project_host(
2185 ProjectId::from_proto(request.project_id),
2186 session.connection_id,
2187 )
2188 .await?;
2189 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2190 session
2191 .peer
2192 .forward_send(session.connection_id, peer_id.into(), request)?;
2193 Ok(())
2194}
2195
2196/// Notify other participants that a buffer has been updated. This is
2197/// allowed for guests as long as the update is limited to selections.
2198async fn update_buffer(
2199 request: proto::UpdateBuffer,
2200 response: Response<proto::UpdateBuffer>,
2201 session: Session,
2202) -> Result<()> {
2203 let project_id = ProjectId::from_proto(request.project_id);
2204 let mut guest_connection_ids;
2205 let mut host_connection_id = None;
2206
2207 let mut requires_write_permission = false;
2208
2209 for op in request.operations.iter() {
2210 match op.variant {
2211 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2212 Some(_) => requires_write_permission = true,
2213 }
2214 }
2215
2216 {
2217 let collaborators = session
2218 .db()
2219 .await
2220 .project_collaborators_for_buffer_update(
2221 project_id,
2222 session.connection_id,
2223 requires_write_permission,
2224 )
2225 .await?;
2226 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2227 for collaborator in collaborators.iter() {
2228 if collaborator.is_host {
2229 host_connection_id = Some(collaborator.connection_id);
2230 } else {
2231 guest_connection_ids.push(collaborator.connection_id);
2232 }
2233 }
2234 }
2235 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2236
2237 broadcast(
2238 Some(session.connection_id),
2239 guest_connection_ids,
2240 |connection_id| {
2241 session
2242 .peer
2243 .forward_send(session.connection_id, connection_id, request.clone())
2244 },
2245 );
2246 if host_connection_id != session.connection_id {
2247 session
2248 .peer
2249 .forward_request(session.connection_id, host_connection_id, request.clone())
2250 .await?;
2251 }
2252
2253 response.send(proto::Ack {})?;
2254 Ok(())
2255}
2256
2257/// Notify other participants that a project has been updated.
2258async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2259 request: T,
2260 session: Session,
2261) -> Result<()> {
2262 let project_id = ProjectId::from_proto(request.remote_entity_id());
2263 let project_connection_ids = session
2264 .db()
2265 .await
2266 .project_connection_ids(project_id, session.connection_id)
2267 .await?;
2268
2269 broadcast(
2270 Some(session.connection_id),
2271 project_connection_ids.iter().copied(),
2272 |connection_id| {
2273 session
2274 .peer
2275 .forward_send(session.connection_id, connection_id, request.clone())
2276 },
2277 );
2278 Ok(())
2279}
2280
2281/// Start following another user in a call.
2282async fn follow(
2283 request: proto::Follow,
2284 response: Response<proto::Follow>,
2285 session: UserSession,
2286) -> Result<()> {
2287 let room_id = RoomId::from_proto(request.room_id);
2288 let project_id = request.project_id.map(ProjectId::from_proto);
2289 let leader_id = request
2290 .leader_id
2291 .ok_or_else(|| anyhow!("invalid leader id"))?
2292 .into();
2293 let follower_id = session.connection_id;
2294
2295 session
2296 .db()
2297 .await
2298 .check_room_participants(room_id, leader_id, session.connection_id)
2299 .await?;
2300
2301 let response_payload = session
2302 .peer
2303 .forward_request(session.connection_id, leader_id, request)
2304 .await?;
2305 response.send(response_payload)?;
2306
2307 if let Some(project_id) = project_id {
2308 let room = session
2309 .db()
2310 .await
2311 .follow(room_id, project_id, leader_id, follower_id)
2312 .await?;
2313 room_updated(&room, &session.peer);
2314 }
2315
2316 Ok(())
2317}
2318
2319/// Stop following another user in a call.
2320async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2321 let room_id = RoomId::from_proto(request.room_id);
2322 let project_id = request.project_id.map(ProjectId::from_proto);
2323 let leader_id = request
2324 .leader_id
2325 .ok_or_else(|| anyhow!("invalid leader id"))?
2326 .into();
2327 let follower_id = session.connection_id;
2328
2329 session
2330 .db()
2331 .await
2332 .check_room_participants(room_id, leader_id, session.connection_id)
2333 .await?;
2334
2335 session
2336 .peer
2337 .forward_send(session.connection_id, leader_id, request)?;
2338
2339 if let Some(project_id) = project_id {
2340 let room = session
2341 .db()
2342 .await
2343 .unfollow(room_id, project_id, leader_id, follower_id)
2344 .await?;
2345 room_updated(&room, &session.peer);
2346 }
2347
2348 Ok(())
2349}
2350
2351/// Notify everyone following you of your current location.
2352async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2353 let room_id = RoomId::from_proto(request.room_id);
2354 let database = session.db.lock().await;
2355
2356 let connection_ids = if let Some(project_id) = request.project_id {
2357 let project_id = ProjectId::from_proto(project_id);
2358 database
2359 .project_connection_ids(project_id, session.connection_id)
2360 .await?
2361 } else {
2362 database
2363 .room_connection_ids(room_id, session.connection_id)
2364 .await?
2365 };
2366
2367 // For now, don't send view update messages back to that view's current leader.
2368 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2369 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2370 _ => None,
2371 });
2372
2373 for connection_id in connection_ids.iter().cloned() {
2374 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2375 session
2376 .peer
2377 .forward_send(session.connection_id, connection_id, request.clone())?;
2378 }
2379 }
2380 Ok(())
2381}
2382
2383/// Get public data about users.
2384async fn get_users(
2385 request: proto::GetUsers,
2386 response: Response<proto::GetUsers>,
2387 session: Session,
2388) -> Result<()> {
2389 let user_ids = request
2390 .user_ids
2391 .into_iter()
2392 .map(UserId::from_proto)
2393 .collect();
2394 let users = session
2395 .db()
2396 .await
2397 .get_users_by_ids(user_ids)
2398 .await?
2399 .into_iter()
2400 .map(|user| proto::User {
2401 id: user.id.to_proto(),
2402 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2403 github_login: user.github_login,
2404 })
2405 .collect();
2406 response.send(proto::UsersResponse { users })?;
2407 Ok(())
2408}
2409
2410/// Search for users (to invite) buy Github login
2411async fn fuzzy_search_users(
2412 request: proto::FuzzySearchUsers,
2413 response: Response<proto::FuzzySearchUsers>,
2414 session: UserSession,
2415) -> Result<()> {
2416 let query = request.query;
2417 let users = match query.len() {
2418 0 => vec![],
2419 1 | 2 => session
2420 .db()
2421 .await
2422 .get_user_by_github_login(&query)
2423 .await?
2424 .into_iter()
2425 .collect(),
2426 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2427 };
2428 let users = users
2429 .into_iter()
2430 .filter(|user| user.id != session.user_id())
2431 .map(|user| proto::User {
2432 id: user.id.to_proto(),
2433 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2434 github_login: user.github_login,
2435 })
2436 .collect();
2437 response.send(proto::UsersResponse { users })?;
2438 Ok(())
2439}
2440
2441/// Send a contact request to another user.
2442async fn request_contact(
2443 request: proto::RequestContact,
2444 response: Response<proto::RequestContact>,
2445 session: UserSession,
2446) -> Result<()> {
2447 let requester_id = session.user_id();
2448 let responder_id = UserId::from_proto(request.responder_id);
2449 if requester_id == responder_id {
2450 return Err(anyhow!("cannot add yourself as a contact"))?;
2451 }
2452
2453 let notifications = session
2454 .db()
2455 .await
2456 .send_contact_request(requester_id, responder_id)
2457 .await?;
2458
2459 // Update outgoing contact requests of requester
2460 let mut update = proto::UpdateContacts::default();
2461 update.outgoing_requests.push(responder_id.to_proto());
2462 for connection_id in session
2463 .connection_pool()
2464 .await
2465 .user_connection_ids(requester_id)
2466 {
2467 session.peer.send(connection_id, update.clone())?;
2468 }
2469
2470 // Update incoming contact requests of responder
2471 let mut update = proto::UpdateContacts::default();
2472 update
2473 .incoming_requests
2474 .push(proto::IncomingContactRequest {
2475 requester_id: requester_id.to_proto(),
2476 });
2477 let connection_pool = session.connection_pool().await;
2478 for connection_id in connection_pool.user_connection_ids(responder_id) {
2479 session.peer.send(connection_id, update.clone())?;
2480 }
2481
2482 send_notifications(&connection_pool, &session.peer, notifications);
2483
2484 response.send(proto::Ack {})?;
2485 Ok(())
2486}
2487
2488/// Accept or decline a contact request
2489async fn respond_to_contact_request(
2490 request: proto::RespondToContactRequest,
2491 response: Response<proto::RespondToContactRequest>,
2492 session: UserSession,
2493) -> Result<()> {
2494 let responder_id = session.user_id();
2495 let requester_id = UserId::from_proto(request.requester_id);
2496 let db = session.db().await;
2497 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2498 db.dismiss_contact_notification(responder_id, requester_id)
2499 .await?;
2500 } else {
2501 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2502
2503 let notifications = db
2504 .respond_to_contact_request(responder_id, requester_id, accept)
2505 .await?;
2506 let requester_busy = db.is_user_busy(requester_id).await?;
2507 let responder_busy = db.is_user_busy(responder_id).await?;
2508
2509 let pool = session.connection_pool().await;
2510 // Update responder with new contact
2511 let mut update = proto::UpdateContacts::default();
2512 if accept {
2513 update
2514 .contacts
2515 .push(contact_for_user(requester_id, requester_busy, &pool));
2516 }
2517 update
2518 .remove_incoming_requests
2519 .push(requester_id.to_proto());
2520 for connection_id in pool.user_connection_ids(responder_id) {
2521 session.peer.send(connection_id, update.clone())?;
2522 }
2523
2524 // Update requester with new contact
2525 let mut update = proto::UpdateContacts::default();
2526 if accept {
2527 update
2528 .contacts
2529 .push(contact_for_user(responder_id, responder_busy, &pool));
2530 }
2531 update
2532 .remove_outgoing_requests
2533 .push(responder_id.to_proto());
2534
2535 for connection_id in pool.user_connection_ids(requester_id) {
2536 session.peer.send(connection_id, update.clone())?;
2537 }
2538
2539 send_notifications(&pool, &session.peer, notifications);
2540 }
2541
2542 response.send(proto::Ack {})?;
2543 Ok(())
2544}
2545
2546/// Remove a contact.
2547async fn remove_contact(
2548 request: proto::RemoveContact,
2549 response: Response<proto::RemoveContact>,
2550 session: UserSession,
2551) -> Result<()> {
2552 let requester_id = session.user_id();
2553 let responder_id = UserId::from_proto(request.user_id);
2554 let db = session.db().await;
2555 let (contact_accepted, deleted_notification_id) =
2556 db.remove_contact(requester_id, responder_id).await?;
2557
2558 let pool = session.connection_pool().await;
2559 // Update outgoing contact requests of requester
2560 let mut update = proto::UpdateContacts::default();
2561 if contact_accepted {
2562 update.remove_contacts.push(responder_id.to_proto());
2563 } else {
2564 update
2565 .remove_outgoing_requests
2566 .push(responder_id.to_proto());
2567 }
2568 for connection_id in pool.user_connection_ids(requester_id) {
2569 session.peer.send(connection_id, update.clone())?;
2570 }
2571
2572 // Update incoming contact requests of responder
2573 let mut update = proto::UpdateContacts::default();
2574 if contact_accepted {
2575 update.remove_contacts.push(requester_id.to_proto());
2576 } else {
2577 update
2578 .remove_incoming_requests
2579 .push(requester_id.to_proto());
2580 }
2581 for connection_id in pool.user_connection_ids(responder_id) {
2582 session.peer.send(connection_id, update.clone())?;
2583 if let Some(notification_id) = deleted_notification_id {
2584 session.peer.send(
2585 connection_id,
2586 proto::DeleteNotification {
2587 notification_id: notification_id.to_proto(),
2588 },
2589 )?;
2590 }
2591 }
2592
2593 response.send(proto::Ack {})?;
2594 Ok(())
2595}
2596
2597/// Creates a new channel.
2598async fn create_channel(
2599 request: proto::CreateChannel,
2600 response: Response<proto::CreateChannel>,
2601 session: UserSession,
2602) -> Result<()> {
2603 let db = session.db().await;
2604
2605 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2606 let (channel, membership) = db
2607 .create_channel(&request.name, parent_id, session.user_id())
2608 .await?;
2609
2610 let root_id = channel.root_id();
2611 let channel = Channel::from_model(channel);
2612
2613 response.send(proto::CreateChannelResponse {
2614 channel: Some(channel.to_proto()),
2615 parent_id: request.parent_id,
2616 })?;
2617
2618 let mut connection_pool = session.connection_pool().await;
2619 if let Some(membership) = membership {
2620 connection_pool.subscribe_to_channel(
2621 membership.user_id,
2622 membership.channel_id,
2623 membership.role,
2624 );
2625 let update = proto::UpdateUserChannels {
2626 channel_memberships: vec![proto::ChannelMembership {
2627 channel_id: membership.channel_id.to_proto(),
2628 role: membership.role.into(),
2629 }],
2630 ..Default::default()
2631 };
2632 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2633 session.peer.send(connection_id, update.clone())?;
2634 }
2635 }
2636
2637 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2638 if !role.can_see_channel(channel.visibility) {
2639 continue;
2640 }
2641
2642 let update = proto::UpdateChannels {
2643 channels: vec![channel.to_proto()],
2644 ..Default::default()
2645 };
2646 session.peer.send(connection_id, update.clone())?;
2647 }
2648
2649 Ok(())
2650}
2651
2652/// Delete a channel
2653async fn delete_channel(
2654 request: proto::DeleteChannel,
2655 response: Response<proto::DeleteChannel>,
2656 session: UserSession,
2657) -> Result<()> {
2658 let db = session.db().await;
2659
2660 let channel_id = request.channel_id;
2661 let (root_channel, removed_channels) = db
2662 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2663 .await?;
2664 response.send(proto::Ack {})?;
2665
2666 // Notify members of removed channels
2667 let mut update = proto::UpdateChannels::default();
2668 update
2669 .delete_channels
2670 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2671
2672 let connection_pool = session.connection_pool().await;
2673 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2674 session.peer.send(connection_id, update.clone())?;
2675 }
2676
2677 Ok(())
2678}
2679
2680/// Invite someone to join a channel.
2681async fn invite_channel_member(
2682 request: proto::InviteChannelMember,
2683 response: Response<proto::InviteChannelMember>,
2684 session: UserSession,
2685) -> Result<()> {
2686 let db = session.db().await;
2687 let channel_id = ChannelId::from_proto(request.channel_id);
2688 let invitee_id = UserId::from_proto(request.user_id);
2689 let InviteMemberResult {
2690 channel,
2691 notifications,
2692 } = db
2693 .invite_channel_member(
2694 channel_id,
2695 invitee_id,
2696 session.user_id(),
2697 request.role().into(),
2698 )
2699 .await?;
2700
2701 let update = proto::UpdateChannels {
2702 channel_invitations: vec![channel.to_proto()],
2703 ..Default::default()
2704 };
2705
2706 let connection_pool = session.connection_pool().await;
2707 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2708 session.peer.send(connection_id, update.clone())?;
2709 }
2710
2711 send_notifications(&connection_pool, &session.peer, notifications);
2712
2713 response.send(proto::Ack {})?;
2714 Ok(())
2715}
2716
2717/// remove someone from a channel
2718async fn remove_channel_member(
2719 request: proto::RemoveChannelMember,
2720 response: Response<proto::RemoveChannelMember>,
2721 session: UserSession,
2722) -> Result<()> {
2723 let db = session.db().await;
2724 let channel_id = ChannelId::from_proto(request.channel_id);
2725 let member_id = UserId::from_proto(request.user_id);
2726
2727 let RemoveChannelMemberResult {
2728 membership_update,
2729 notification_id,
2730 } = db
2731 .remove_channel_member(channel_id, member_id, session.user_id())
2732 .await?;
2733
2734 let mut connection_pool = session.connection_pool().await;
2735 notify_membership_updated(
2736 &mut connection_pool,
2737 membership_update,
2738 member_id,
2739 &session.peer,
2740 );
2741 for connection_id in connection_pool.user_connection_ids(member_id) {
2742 if let Some(notification_id) = notification_id {
2743 session
2744 .peer
2745 .send(
2746 connection_id,
2747 proto::DeleteNotification {
2748 notification_id: notification_id.to_proto(),
2749 },
2750 )
2751 .trace_err();
2752 }
2753 }
2754
2755 response.send(proto::Ack {})?;
2756 Ok(())
2757}
2758
2759/// Toggle the channel between public and private.
2760/// Care is taken to maintain the invariant that public channels only descend from public channels,
2761/// (though members-only channels can appear at any point in the hierarchy).
2762async fn set_channel_visibility(
2763 request: proto::SetChannelVisibility,
2764 response: Response<proto::SetChannelVisibility>,
2765 session: UserSession,
2766) -> Result<()> {
2767 let db = session.db().await;
2768 let channel_id = ChannelId::from_proto(request.channel_id);
2769 let visibility = request.visibility().into();
2770
2771 let channel_model = db
2772 .set_channel_visibility(channel_id, visibility, session.user_id())
2773 .await?;
2774 let root_id = channel_model.root_id();
2775 let channel = Channel::from_model(channel_model);
2776
2777 let mut connection_pool = session.connection_pool().await;
2778 for (user_id, role) in connection_pool
2779 .channel_user_ids(root_id)
2780 .collect::<Vec<_>>()
2781 .into_iter()
2782 {
2783 let update = if role.can_see_channel(channel.visibility) {
2784 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2785 proto::UpdateChannels {
2786 channels: vec![channel.to_proto()],
2787 ..Default::default()
2788 }
2789 } else {
2790 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2791 proto::UpdateChannels {
2792 delete_channels: vec![channel.id.to_proto()],
2793 ..Default::default()
2794 }
2795 };
2796
2797 for connection_id in connection_pool.user_connection_ids(user_id) {
2798 session.peer.send(connection_id, update.clone())?;
2799 }
2800 }
2801
2802 response.send(proto::Ack {})?;
2803 Ok(())
2804}
2805
2806/// Alter the role for a user in the channel.
2807async fn set_channel_member_role(
2808 request: proto::SetChannelMemberRole,
2809 response: Response<proto::SetChannelMemberRole>,
2810 session: UserSession,
2811) -> Result<()> {
2812 let db = session.db().await;
2813 let channel_id = ChannelId::from_proto(request.channel_id);
2814 let member_id = UserId::from_proto(request.user_id);
2815 let result = db
2816 .set_channel_member_role(
2817 channel_id,
2818 session.user_id(),
2819 member_id,
2820 request.role().into(),
2821 )
2822 .await?;
2823
2824 match result {
2825 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2826 let mut connection_pool = session.connection_pool().await;
2827 notify_membership_updated(
2828 &mut connection_pool,
2829 membership_update,
2830 member_id,
2831 &session.peer,
2832 )
2833 }
2834 db::SetMemberRoleResult::InviteUpdated(channel) => {
2835 let update = proto::UpdateChannels {
2836 channel_invitations: vec![channel.to_proto()],
2837 ..Default::default()
2838 };
2839
2840 for connection_id in session
2841 .connection_pool()
2842 .await
2843 .user_connection_ids(member_id)
2844 {
2845 session.peer.send(connection_id, update.clone())?;
2846 }
2847 }
2848 }
2849
2850 response.send(proto::Ack {})?;
2851 Ok(())
2852}
2853
2854/// Change the name of a channel
2855async fn rename_channel(
2856 request: proto::RenameChannel,
2857 response: Response<proto::RenameChannel>,
2858 session: UserSession,
2859) -> Result<()> {
2860 let db = session.db().await;
2861 let channel_id = ChannelId::from_proto(request.channel_id);
2862 let channel_model = db
2863 .rename_channel(channel_id, session.user_id(), &request.name)
2864 .await?;
2865 let root_id = channel_model.root_id();
2866 let channel = Channel::from_model(channel_model);
2867
2868 response.send(proto::RenameChannelResponse {
2869 channel: Some(channel.to_proto()),
2870 })?;
2871
2872 let connection_pool = session.connection_pool().await;
2873 let update = proto::UpdateChannels {
2874 channels: vec![channel.to_proto()],
2875 ..Default::default()
2876 };
2877 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2878 if role.can_see_channel(channel.visibility) {
2879 session.peer.send(connection_id, update.clone())?;
2880 }
2881 }
2882
2883 Ok(())
2884}
2885
2886/// Move a channel to a new parent.
2887async fn move_channel(
2888 request: proto::MoveChannel,
2889 response: Response<proto::MoveChannel>,
2890 session: UserSession,
2891) -> Result<()> {
2892 let channel_id = ChannelId::from_proto(request.channel_id);
2893 let to = ChannelId::from_proto(request.to);
2894
2895 let (root_id, channels) = session
2896 .db()
2897 .await
2898 .move_channel(channel_id, to, session.user_id())
2899 .await?;
2900
2901 let connection_pool = session.connection_pool().await;
2902 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2903 let channels = channels
2904 .iter()
2905 .filter_map(|channel| {
2906 if role.can_see_channel(channel.visibility) {
2907 Some(channel.to_proto())
2908 } else {
2909 None
2910 }
2911 })
2912 .collect::<Vec<_>>();
2913 if channels.is_empty() {
2914 continue;
2915 }
2916
2917 let update = proto::UpdateChannels {
2918 channels,
2919 ..Default::default()
2920 };
2921
2922 session.peer.send(connection_id, update.clone())?;
2923 }
2924
2925 response.send(Ack {})?;
2926 Ok(())
2927}
2928
2929/// Get the list of channel members
2930async fn get_channel_members(
2931 request: proto::GetChannelMembers,
2932 response: Response<proto::GetChannelMembers>,
2933 session: UserSession,
2934) -> Result<()> {
2935 let db = session.db().await;
2936 let channel_id = ChannelId::from_proto(request.channel_id);
2937 let members = db
2938 .get_channel_participant_details(channel_id, session.user_id())
2939 .await?;
2940 response.send(proto::GetChannelMembersResponse { members })?;
2941 Ok(())
2942}
2943
2944/// Accept or decline a channel invitation.
2945async fn respond_to_channel_invite(
2946 request: proto::RespondToChannelInvite,
2947 response: Response<proto::RespondToChannelInvite>,
2948 session: UserSession,
2949) -> Result<()> {
2950 let db = session.db().await;
2951 let channel_id = ChannelId::from_proto(request.channel_id);
2952 let RespondToChannelInvite {
2953 membership_update,
2954 notifications,
2955 } = db
2956 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2957 .await?;
2958
2959 let mut connection_pool = session.connection_pool().await;
2960 if let Some(membership_update) = membership_update {
2961 notify_membership_updated(
2962 &mut connection_pool,
2963 membership_update,
2964 session.user_id(),
2965 &session.peer,
2966 );
2967 } else {
2968 let update = proto::UpdateChannels {
2969 remove_channel_invitations: vec![channel_id.to_proto()],
2970 ..Default::default()
2971 };
2972
2973 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2974 session.peer.send(connection_id, update.clone())?;
2975 }
2976 };
2977
2978 send_notifications(&connection_pool, &session.peer, notifications);
2979
2980 response.send(proto::Ack {})?;
2981
2982 Ok(())
2983}
2984
2985/// Join the channels' room
2986async fn join_channel(
2987 request: proto::JoinChannel,
2988 response: Response<proto::JoinChannel>,
2989 session: UserSession,
2990) -> Result<()> {
2991 let channel_id = ChannelId::from_proto(request.channel_id);
2992 join_channel_internal(channel_id, Box::new(response), session).await
2993}
2994
2995trait JoinChannelInternalResponse {
2996 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2997}
2998impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2999 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3000 Response::<proto::JoinChannel>::send(self, result)
3001 }
3002}
3003impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3004 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3005 Response::<proto::JoinRoom>::send(self, result)
3006 }
3007}
3008
3009async fn join_channel_internal(
3010 channel_id: ChannelId,
3011 response: Box<impl JoinChannelInternalResponse>,
3012 session: UserSession,
3013) -> Result<()> {
3014 let joined_room = {
3015 leave_room_for_session(&session).await?;
3016 let db = session.db().await;
3017
3018 let (joined_room, membership_updated, role) = db
3019 .join_channel(channel_id, session.user_id(), session.connection_id)
3020 .await?;
3021
3022 let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3023 let (can_publish, token) = if role == ChannelRole::Guest {
3024 (
3025 false,
3026 live_kit
3027 .guest_token(
3028 &joined_room.room.live_kit_room,
3029 &session.user_id().to_string(),
3030 )
3031 .trace_err()?,
3032 )
3033 } else {
3034 (
3035 true,
3036 live_kit
3037 .room_token(
3038 &joined_room.room.live_kit_room,
3039 &session.user_id().to_string(),
3040 )
3041 .trace_err()?,
3042 )
3043 };
3044
3045 Some(LiveKitConnectionInfo {
3046 server_url: live_kit.url().into(),
3047 token,
3048 can_publish,
3049 })
3050 });
3051
3052 response.send(proto::JoinRoomResponse {
3053 room: Some(joined_room.room.clone()),
3054 channel_id: joined_room
3055 .channel
3056 .as_ref()
3057 .map(|channel| channel.id.to_proto()),
3058 live_kit_connection_info,
3059 })?;
3060
3061 let mut connection_pool = session.connection_pool().await;
3062 if let Some(membership_updated) = membership_updated {
3063 notify_membership_updated(
3064 &mut connection_pool,
3065 membership_updated,
3066 session.user_id(),
3067 &session.peer,
3068 );
3069 }
3070
3071 room_updated(&joined_room.room, &session.peer);
3072
3073 joined_room
3074 };
3075
3076 channel_updated(
3077 &joined_room
3078 .channel
3079 .ok_or_else(|| anyhow!("channel not returned"))?,
3080 &joined_room.room,
3081 &session.peer,
3082 &*session.connection_pool().await,
3083 );
3084
3085 update_user_contacts(session.user_id(), &session).await?;
3086 Ok(())
3087}
3088
3089/// Start editing the channel notes
3090async fn join_channel_buffer(
3091 request: proto::JoinChannelBuffer,
3092 response: Response<proto::JoinChannelBuffer>,
3093 session: UserSession,
3094) -> Result<()> {
3095 let db = session.db().await;
3096 let channel_id = ChannelId::from_proto(request.channel_id);
3097
3098 let open_response = db
3099 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3100 .await?;
3101
3102 let collaborators = open_response.collaborators.clone();
3103 response.send(open_response)?;
3104
3105 let update = UpdateChannelBufferCollaborators {
3106 channel_id: channel_id.to_proto(),
3107 collaborators: collaborators.clone(),
3108 };
3109 channel_buffer_updated(
3110 session.connection_id,
3111 collaborators
3112 .iter()
3113 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3114 &update,
3115 &session.peer,
3116 );
3117
3118 Ok(())
3119}
3120
3121/// Edit the channel notes
3122async fn update_channel_buffer(
3123 request: proto::UpdateChannelBuffer,
3124 session: UserSession,
3125) -> Result<()> {
3126 let db = session.db().await;
3127 let channel_id = ChannelId::from_proto(request.channel_id);
3128
3129 let (collaborators, non_collaborators, epoch, version) = db
3130 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3131 .await?;
3132
3133 channel_buffer_updated(
3134 session.connection_id,
3135 collaborators,
3136 &proto::UpdateChannelBuffer {
3137 channel_id: channel_id.to_proto(),
3138 operations: request.operations,
3139 },
3140 &session.peer,
3141 );
3142
3143 let pool = &*session.connection_pool().await;
3144
3145 broadcast(
3146 None,
3147 non_collaborators
3148 .iter()
3149 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3150 |peer_id| {
3151 session.peer.send(
3152 peer_id,
3153 proto::UpdateChannels {
3154 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3155 channel_id: channel_id.to_proto(),
3156 epoch: epoch as u64,
3157 version: version.clone(),
3158 }],
3159 ..Default::default()
3160 },
3161 )
3162 },
3163 );
3164
3165 Ok(())
3166}
3167
3168/// Rejoin the channel notes after a connection blip
3169async fn rejoin_channel_buffers(
3170 request: proto::RejoinChannelBuffers,
3171 response: Response<proto::RejoinChannelBuffers>,
3172 session: UserSession,
3173) -> Result<()> {
3174 let db = session.db().await;
3175 let buffers = db
3176 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3177 .await?;
3178
3179 for rejoined_buffer in &buffers {
3180 let collaborators_to_notify = rejoined_buffer
3181 .buffer
3182 .collaborators
3183 .iter()
3184 .filter_map(|c| Some(c.peer_id?.into()));
3185 channel_buffer_updated(
3186 session.connection_id,
3187 collaborators_to_notify,
3188 &proto::UpdateChannelBufferCollaborators {
3189 channel_id: rejoined_buffer.buffer.channel_id,
3190 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3191 },
3192 &session.peer,
3193 );
3194 }
3195
3196 response.send(proto::RejoinChannelBuffersResponse {
3197 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3198 })?;
3199
3200 Ok(())
3201}
3202
3203/// Stop editing the channel notes
3204async fn leave_channel_buffer(
3205 request: proto::LeaveChannelBuffer,
3206 response: Response<proto::LeaveChannelBuffer>,
3207 session: UserSession,
3208) -> Result<()> {
3209 let db = session.db().await;
3210 let channel_id = ChannelId::from_proto(request.channel_id);
3211
3212 let left_buffer = db
3213 .leave_channel_buffer(channel_id, session.connection_id)
3214 .await?;
3215
3216 response.send(Ack {})?;
3217
3218 channel_buffer_updated(
3219 session.connection_id,
3220 left_buffer.connections,
3221 &proto::UpdateChannelBufferCollaborators {
3222 channel_id: channel_id.to_proto(),
3223 collaborators: left_buffer.collaborators,
3224 },
3225 &session.peer,
3226 );
3227
3228 Ok(())
3229}
3230
3231fn channel_buffer_updated<T: EnvelopedMessage>(
3232 sender_id: ConnectionId,
3233 collaborators: impl IntoIterator<Item = ConnectionId>,
3234 message: &T,
3235 peer: &Peer,
3236) {
3237 broadcast(Some(sender_id), collaborators, |peer_id| {
3238 peer.send(peer_id, message.clone())
3239 });
3240}
3241
3242fn send_notifications(
3243 connection_pool: &ConnectionPool,
3244 peer: &Peer,
3245 notifications: db::NotificationBatch,
3246) {
3247 for (user_id, notification) in notifications {
3248 for connection_id in connection_pool.user_connection_ids(user_id) {
3249 if let Err(error) = peer.send(
3250 connection_id,
3251 proto::AddNotification {
3252 notification: Some(notification.clone()),
3253 },
3254 ) {
3255 tracing::error!(
3256 "failed to send notification to {:?} {}",
3257 connection_id,
3258 error
3259 );
3260 }
3261 }
3262 }
3263}
3264
3265/// Send a message to the channel
3266async fn send_channel_message(
3267 request: proto::SendChannelMessage,
3268 response: Response<proto::SendChannelMessage>,
3269 session: UserSession,
3270) -> Result<()> {
3271 // Validate the message body.
3272 let body = request.body.trim().to_string();
3273 if body.len() > MAX_MESSAGE_LEN {
3274 return Err(anyhow!("message is too long"))?;
3275 }
3276 if body.is_empty() {
3277 return Err(anyhow!("message can't be blank"))?;
3278 }
3279
3280 // TODO: adjust mentions if body is trimmed
3281
3282 let timestamp = OffsetDateTime::now_utc();
3283 let nonce = request
3284 .nonce
3285 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3286
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288 let CreatedChannelMessage {
3289 message_id,
3290 participant_connection_ids,
3291 channel_members,
3292 notifications,
3293 } = session
3294 .db()
3295 .await
3296 .create_channel_message(
3297 channel_id,
3298 session.user_id(),
3299 &body,
3300 &request.mentions,
3301 timestamp,
3302 nonce.clone().into(),
3303 match request.reply_to_message_id {
3304 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3305 None => None,
3306 },
3307 )
3308 .await?;
3309
3310 let message = proto::ChannelMessage {
3311 sender_id: session.user_id().to_proto(),
3312 id: message_id.to_proto(),
3313 body,
3314 mentions: request.mentions,
3315 timestamp: timestamp.unix_timestamp() as u64,
3316 nonce: Some(nonce),
3317 reply_to_message_id: request.reply_to_message_id,
3318 edited_at: None,
3319 };
3320 broadcast(
3321 Some(session.connection_id),
3322 participant_connection_ids,
3323 |connection| {
3324 session.peer.send(
3325 connection,
3326 proto::ChannelMessageSent {
3327 channel_id: channel_id.to_proto(),
3328 message: Some(message.clone()),
3329 },
3330 )
3331 },
3332 );
3333 response.send(proto::SendChannelMessageResponse {
3334 message: Some(message),
3335 })?;
3336
3337 let pool = &*session.connection_pool().await;
3338 broadcast(
3339 None,
3340 channel_members
3341 .iter()
3342 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3343 |peer_id| {
3344 session.peer.send(
3345 peer_id,
3346 proto::UpdateChannels {
3347 latest_channel_message_ids: vec![proto::ChannelMessageId {
3348 channel_id: channel_id.to_proto(),
3349 message_id: message_id.to_proto(),
3350 }],
3351 ..Default::default()
3352 },
3353 )
3354 },
3355 );
3356 send_notifications(pool, &session.peer, notifications);
3357
3358 Ok(())
3359}
3360
3361/// Delete a channel message
3362async fn remove_channel_message(
3363 request: proto::RemoveChannelMessage,
3364 response: Response<proto::RemoveChannelMessage>,
3365 session: UserSession,
3366) -> Result<()> {
3367 let channel_id = ChannelId::from_proto(request.channel_id);
3368 let message_id = MessageId::from_proto(request.message_id);
3369 let connection_ids = session
3370 .db()
3371 .await
3372 .remove_channel_message(channel_id, message_id, session.user_id())
3373 .await?;
3374 broadcast(Some(session.connection_id), connection_ids, |connection| {
3375 session.peer.send(connection, request.clone())
3376 });
3377 response.send(proto::Ack {})?;
3378 Ok(())
3379}
3380
3381async fn update_channel_message(
3382 request: proto::UpdateChannelMessage,
3383 response: Response<proto::UpdateChannelMessage>,
3384 session: UserSession,
3385) -> Result<()> {
3386 let channel_id = ChannelId::from_proto(request.channel_id);
3387 let message_id = MessageId::from_proto(request.message_id);
3388 let updated_at = OffsetDateTime::now_utc();
3389 let UpdatedChannelMessage {
3390 message_id,
3391 participant_connection_ids,
3392 notifications,
3393 reply_to_message_id,
3394 timestamp,
3395 } = session
3396 .db()
3397 .await
3398 .update_channel_message(
3399 channel_id,
3400 message_id,
3401 session.user_id(),
3402 request.body.as_str(),
3403 &request.mentions,
3404 updated_at,
3405 )
3406 .await?;
3407
3408 let nonce = request
3409 .nonce
3410 .clone()
3411 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3412
3413 let message = proto::ChannelMessage {
3414 sender_id: session.user_id().to_proto(),
3415 id: message_id.to_proto(),
3416 body: request.body.clone(),
3417 mentions: request.mentions.clone(),
3418 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3419 nonce: Some(nonce),
3420 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3421 edited_at: Some(updated_at.unix_timestamp() as u64),
3422 };
3423
3424 response.send(proto::Ack {})?;
3425
3426 let pool = &*session.connection_pool().await;
3427 broadcast(
3428 Some(session.connection_id),
3429 participant_connection_ids,
3430 |connection| {
3431 session.peer.send(
3432 connection,
3433 proto::ChannelMessageUpdate {
3434 channel_id: channel_id.to_proto(),
3435 message: Some(message.clone()),
3436 },
3437 )
3438 },
3439 );
3440
3441 send_notifications(pool, &session.peer, notifications);
3442
3443 Ok(())
3444}
3445
3446/// Mark a channel message as read
3447async fn acknowledge_channel_message(
3448 request: proto::AckChannelMessage,
3449 session: UserSession,
3450) -> Result<()> {
3451 let channel_id = ChannelId::from_proto(request.channel_id);
3452 let message_id = MessageId::from_proto(request.message_id);
3453 let notifications = session
3454 .db()
3455 .await
3456 .observe_channel_message(channel_id, session.user_id(), message_id)
3457 .await?;
3458 send_notifications(
3459 &*session.connection_pool().await,
3460 &session.peer,
3461 notifications,
3462 );
3463 Ok(())
3464}
3465
3466/// Mark a buffer version as synced
3467async fn acknowledge_buffer_version(
3468 request: proto::AckBufferOperation,
3469 session: UserSession,
3470) -> Result<()> {
3471 let buffer_id = BufferId::from_proto(request.buffer_id);
3472 session
3473 .db()
3474 .await
3475 .observe_buffer_version(
3476 buffer_id,
3477 session.user_id(),
3478 request.epoch as i32,
3479 &request.version,
3480 )
3481 .await?;
3482 Ok(())
3483}
3484
3485struct CompleteWithLanguageModelRateLimit;
3486
3487impl RateLimit for CompleteWithLanguageModelRateLimit {
3488 fn capacity() -> usize {
3489 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3490 .ok()
3491 .and_then(|v| v.parse().ok())
3492 .unwrap_or(120) // Picked arbitrarily
3493 }
3494
3495 fn refill_duration() -> chrono::Duration {
3496 chrono::Duration::hours(1)
3497 }
3498
3499 fn db_name() -> &'static str {
3500 "complete-with-language-model"
3501 }
3502}
3503
3504async fn complete_with_language_model(
3505 request: proto::CompleteWithLanguageModel,
3506 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3507 session: Session,
3508 open_ai_api_key: Option<Arc<str>>,
3509 google_ai_api_key: Option<Arc<str>>,
3510 anthropic_api_key: Option<Arc<str>>,
3511) -> Result<()> {
3512 let Some(session) = session.for_user() else {
3513 return Err(anyhow!("user not found"))?;
3514 };
3515 authorize_access_to_language_models(&session).await?;
3516 session
3517 .rate_limiter
3518 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3519 .await?;
3520
3521 if request.model.starts_with("gpt") {
3522 let api_key =
3523 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3524 complete_with_open_ai(request, response, session, api_key).await?;
3525 } else if request.model.starts_with("gemini") {
3526 let api_key = google_ai_api_key
3527 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3528 complete_with_google_ai(request, response, session, api_key).await?;
3529 } else if request.model.starts_with("claude") {
3530 let api_key = anthropic_api_key
3531 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3532 complete_with_anthropic(request, response, session, api_key).await?;
3533 }
3534
3535 Ok(())
3536}
3537
3538async fn complete_with_open_ai(
3539 request: proto::CompleteWithLanguageModel,
3540 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3541 session: UserSession,
3542 api_key: Arc<str>,
3543) -> Result<()> {
3544 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3545
3546 let mut completion_stream = open_ai::stream_completion(
3547 &session.http_client,
3548 OPEN_AI_API_URL,
3549 &api_key,
3550 crate::ai::language_model_request_to_open_ai(request)?,
3551 )
3552 .await
3553 .context("open_ai::stream_completion request failed")?;
3554
3555 while let Some(event) = completion_stream.next().await {
3556 let event = event?;
3557 response.send(proto::LanguageModelResponse {
3558 choices: event
3559 .choices
3560 .into_iter()
3561 .map(|choice| proto::LanguageModelChoiceDelta {
3562 index: choice.index,
3563 delta: Some(proto::LanguageModelResponseMessage {
3564 role: choice.delta.role.map(|role| match role {
3565 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3566 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3567 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3568 } as i32),
3569 content: choice.delta.content,
3570 }),
3571 finish_reason: choice.finish_reason,
3572 })
3573 .collect(),
3574 })?;
3575 }
3576
3577 Ok(())
3578}
3579
3580async fn complete_with_google_ai(
3581 request: proto::CompleteWithLanguageModel,
3582 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3583 session: UserSession,
3584 api_key: Arc<str>,
3585) -> Result<()> {
3586 let mut stream = google_ai::stream_generate_content(
3587 &session.http_client,
3588 google_ai::API_URL,
3589 api_key.as_ref(),
3590 crate::ai::language_model_request_to_google_ai(request)?,
3591 )
3592 .await
3593 .context("google_ai::stream_generate_content request failed")?;
3594
3595 while let Some(event) = stream.next().await {
3596 let event = event?;
3597 response.send(proto::LanguageModelResponse {
3598 choices: event
3599 .candidates
3600 .unwrap_or_default()
3601 .into_iter()
3602 .map(|candidate| proto::LanguageModelChoiceDelta {
3603 index: candidate.index as u32,
3604 delta: Some(proto::LanguageModelResponseMessage {
3605 role: Some(match candidate.content.role {
3606 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3607 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3608 } as i32),
3609 content: Some(
3610 candidate
3611 .content
3612 .parts
3613 .into_iter()
3614 .filter_map(|part| match part {
3615 google_ai::Part::TextPart(part) => Some(part.text),
3616 google_ai::Part::InlineDataPart(_) => None,
3617 })
3618 .collect(),
3619 ),
3620 }),
3621 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3622 })
3623 .collect(),
3624 })?;
3625 }
3626
3627 Ok(())
3628}
3629
3630async fn complete_with_anthropic(
3631 request: proto::CompleteWithLanguageModel,
3632 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3633 session: UserSession,
3634 api_key: Arc<str>,
3635) -> Result<()> {
3636 let model = anthropic::Model::from_id(&request.model)?;
3637
3638 let mut system_message = String::new();
3639 let messages = request
3640 .messages
3641 .into_iter()
3642 .filter_map(|message| match message.role() {
3643 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
3644 role: anthropic::Role::User,
3645 content: message.content,
3646 }),
3647 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
3648 role: anthropic::Role::Assistant,
3649 content: message.content,
3650 }),
3651 // Anthropic's API breaks system instructions out as a separate field rather
3652 // than having a system message role.
3653 LanguageModelRole::LanguageModelSystem => {
3654 if !system_message.is_empty() {
3655 system_message.push_str("\n\n");
3656 }
3657 system_message.push_str(&message.content);
3658
3659 None
3660 }
3661 })
3662 .collect();
3663
3664 let mut stream = anthropic::stream_completion(
3665 &session.http_client,
3666 "https://api.anthropic.com",
3667 &api_key,
3668 anthropic::Request {
3669 model,
3670 messages,
3671 stream: true,
3672 system: system_message,
3673 max_tokens: 4092,
3674 },
3675 )
3676 .await?;
3677
3678 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
3679
3680 while let Some(event) = stream.next().await {
3681 let event = event?;
3682
3683 match event {
3684 anthropic::ResponseEvent::MessageStart { message } => {
3685 if let Some(role) = message.role {
3686 if role == "assistant" {
3687 current_role = proto::LanguageModelRole::LanguageModelAssistant;
3688 } else if role == "user" {
3689 current_role = proto::LanguageModelRole::LanguageModelUser;
3690 }
3691 }
3692 }
3693 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
3694 match content_block {
3695 anthropic::ContentBlock::Text { text } => {
3696 if !text.is_empty() {
3697 response.send(proto::LanguageModelResponse {
3698 choices: vec![proto::LanguageModelChoiceDelta {
3699 index: 0,
3700 delta: Some(proto::LanguageModelResponseMessage {
3701 role: Some(current_role as i32),
3702 content: Some(text),
3703 }),
3704 finish_reason: None,
3705 }],
3706 })?;
3707 }
3708 }
3709 }
3710 }
3711 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
3712 anthropic::TextDelta::TextDelta { text } => {
3713 response.send(proto::LanguageModelResponse {
3714 choices: vec![proto::LanguageModelChoiceDelta {
3715 index: 0,
3716 delta: Some(proto::LanguageModelResponseMessage {
3717 role: Some(current_role as i32),
3718 content: Some(text),
3719 }),
3720 finish_reason: None,
3721 }],
3722 })?;
3723 }
3724 },
3725 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
3726 if let Some(stop_reason) = delta.stop_reason {
3727 response.send(proto::LanguageModelResponse {
3728 choices: vec![proto::LanguageModelChoiceDelta {
3729 index: 0,
3730 delta: None,
3731 finish_reason: Some(stop_reason),
3732 }],
3733 })?;
3734 }
3735 }
3736 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
3737 anthropic::ResponseEvent::MessageStop {} => {}
3738 anthropic::ResponseEvent::Ping {} => {}
3739 }
3740 }
3741
3742 Ok(())
3743}
3744
3745struct CountTokensWithLanguageModelRateLimit;
3746
3747impl RateLimit for CountTokensWithLanguageModelRateLimit {
3748 fn capacity() -> usize {
3749 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3750 .ok()
3751 .and_then(|v| v.parse().ok())
3752 .unwrap_or(600) // Picked arbitrarily
3753 }
3754
3755 fn refill_duration() -> chrono::Duration {
3756 chrono::Duration::hours(1)
3757 }
3758
3759 fn db_name() -> &'static str {
3760 "count-tokens-with-language-model"
3761 }
3762}
3763
3764async fn count_tokens_with_language_model(
3765 request: proto::CountTokensWithLanguageModel,
3766 response: Response<proto::CountTokensWithLanguageModel>,
3767 session: UserSession,
3768 google_ai_api_key: Option<Arc<str>>,
3769) -> Result<()> {
3770 authorize_access_to_language_models(&session).await?;
3771
3772 if !request.model.starts_with("gemini") {
3773 return Err(anyhow!(
3774 "counting tokens for model: {:?} is not supported",
3775 request.model
3776 ))?;
3777 }
3778
3779 session
3780 .rate_limiter
3781 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3782 .await?;
3783
3784 let api_key = google_ai_api_key
3785 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3786 let tokens_response = google_ai::count_tokens(
3787 &session.http_client,
3788 google_ai::API_URL,
3789 &api_key,
3790 crate::ai::count_tokens_request_to_google_ai(request)?,
3791 )
3792 .await?;
3793 response.send(proto::CountTokensResponse {
3794 token_count: tokens_response.total_tokens as u32,
3795 })?;
3796 Ok(())
3797}
3798
3799async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3800 let db = session.db().await;
3801 let flags = db.get_user_flags(session.user_id()).await?;
3802 if flags.iter().any(|flag| flag == "language-models") {
3803 Ok(())
3804 } else {
3805 Err(anyhow!("permission denied"))?
3806 }
3807}
3808
3809/// Start receiving chat updates for a channel
3810async fn join_channel_chat(
3811 request: proto::JoinChannelChat,
3812 response: Response<proto::JoinChannelChat>,
3813 session: UserSession,
3814) -> Result<()> {
3815 let channel_id = ChannelId::from_proto(request.channel_id);
3816
3817 let db = session.db().await;
3818 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3819 .await?;
3820 let messages = db
3821 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3822 .await?;
3823 response.send(proto::JoinChannelChatResponse {
3824 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3825 messages,
3826 })?;
3827 Ok(())
3828}
3829
3830/// Stop receiving chat updates for a channel
3831async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3832 let channel_id = ChannelId::from_proto(request.channel_id);
3833 session
3834 .db()
3835 .await
3836 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3837 .await?;
3838 Ok(())
3839}
3840
3841/// Retrieve the chat history for a channel
3842async fn get_channel_messages(
3843 request: proto::GetChannelMessages,
3844 response: Response<proto::GetChannelMessages>,
3845 session: UserSession,
3846) -> Result<()> {
3847 let channel_id = ChannelId::from_proto(request.channel_id);
3848 let messages = session
3849 .db()
3850 .await
3851 .get_channel_messages(
3852 channel_id,
3853 session.user_id(),
3854 MESSAGE_COUNT_PER_PAGE,
3855 Some(MessageId::from_proto(request.before_message_id)),
3856 )
3857 .await?;
3858 response.send(proto::GetChannelMessagesResponse {
3859 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3860 messages,
3861 })?;
3862 Ok(())
3863}
3864
3865/// Retrieve specific chat messages
3866async fn get_channel_messages_by_id(
3867 request: proto::GetChannelMessagesById,
3868 response: Response<proto::GetChannelMessagesById>,
3869 session: UserSession,
3870) -> Result<()> {
3871 let message_ids = request
3872 .message_ids
3873 .iter()
3874 .map(|id| MessageId::from_proto(*id))
3875 .collect::<Vec<_>>();
3876 let messages = session
3877 .db()
3878 .await
3879 .get_channel_messages_by_id(session.user_id(), &message_ids)
3880 .await?;
3881 response.send(proto::GetChannelMessagesResponse {
3882 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3883 messages,
3884 })?;
3885 Ok(())
3886}
3887
3888/// Retrieve the current users notifications
3889async fn get_notifications(
3890 request: proto::GetNotifications,
3891 response: Response<proto::GetNotifications>,
3892 session: UserSession,
3893) -> Result<()> {
3894 let notifications = session
3895 .db()
3896 .await
3897 .get_notifications(
3898 session.user_id(),
3899 NOTIFICATION_COUNT_PER_PAGE,
3900 request
3901 .before_id
3902 .map(|id| db::NotificationId::from_proto(id)),
3903 )
3904 .await?;
3905 response.send(proto::GetNotificationsResponse {
3906 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3907 notifications,
3908 })?;
3909 Ok(())
3910}
3911
3912/// Mark notifications as read
3913async fn mark_notification_as_read(
3914 request: proto::MarkNotificationRead,
3915 response: Response<proto::MarkNotificationRead>,
3916 session: UserSession,
3917) -> Result<()> {
3918 let database = &session.db().await;
3919 let notifications = database
3920 .mark_notification_as_read_by_id(
3921 session.user_id(),
3922 NotificationId::from_proto(request.notification_id),
3923 )
3924 .await?;
3925 send_notifications(
3926 &*session.connection_pool().await,
3927 &session.peer,
3928 notifications,
3929 );
3930 response.send(proto::Ack {})?;
3931 Ok(())
3932}
3933
3934/// Get the current users information
3935async fn get_private_user_info(
3936 _request: proto::GetPrivateUserInfo,
3937 response: Response<proto::GetPrivateUserInfo>,
3938 session: UserSession,
3939) -> Result<()> {
3940 let db = session.db().await;
3941
3942 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3943 let user = db
3944 .get_user_by_id(session.user_id())
3945 .await?
3946 .ok_or_else(|| anyhow!("user not found"))?;
3947 let flags = db.get_user_flags(session.user_id()).await?;
3948
3949 response.send(proto::GetPrivateUserInfoResponse {
3950 metrics_id,
3951 staff: user.admin,
3952 flags,
3953 })?;
3954 Ok(())
3955}
3956
3957fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3958 match message {
3959 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3960 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3961 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3962 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3963 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3964 code: frame.code.into(),
3965 reason: frame.reason,
3966 })),
3967 }
3968}
3969
3970fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3971 match message {
3972 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3973 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3974 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3975 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3976 AxumMessage::Close(frame) => {
3977 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3978 code: frame.code.into(),
3979 reason: frame.reason,
3980 }))
3981 }
3982 }
3983}
3984
3985fn notify_membership_updated(
3986 connection_pool: &mut ConnectionPool,
3987 result: MembershipUpdated,
3988 user_id: UserId,
3989 peer: &Peer,
3990) {
3991 for membership in &result.new_channels.channel_memberships {
3992 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3993 }
3994 for channel_id in &result.removed_channels {
3995 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3996 }
3997
3998 let user_channels_update = proto::UpdateUserChannels {
3999 channel_memberships: result
4000 .new_channels
4001 .channel_memberships
4002 .iter()
4003 .map(|cm| proto::ChannelMembership {
4004 channel_id: cm.channel_id.to_proto(),
4005 role: cm.role.into(),
4006 })
4007 .collect(),
4008 ..Default::default()
4009 };
4010
4011 let mut update = build_channels_update(result.new_channels, vec![]);
4012 update.delete_channels = result
4013 .removed_channels
4014 .into_iter()
4015 .map(|id| id.to_proto())
4016 .collect();
4017 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4018
4019 for connection_id in connection_pool.user_connection_ids(user_id) {
4020 peer.send(connection_id, user_channels_update.clone())
4021 .trace_err();
4022 peer.send(connection_id, update.clone()).trace_err();
4023 }
4024}
4025
4026fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4027 proto::UpdateUserChannels {
4028 channel_memberships: channels
4029 .channel_memberships
4030 .iter()
4031 .map(|m| proto::ChannelMembership {
4032 channel_id: m.channel_id.to_proto(),
4033 role: m.role.into(),
4034 })
4035 .collect(),
4036 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4037 observed_channel_message_id: channels.observed_channel_messages.clone(),
4038 }
4039}
4040
4041fn build_channels_update(
4042 channels: ChannelsForUser,
4043 channel_invites: Vec<db::Channel>,
4044) -> proto::UpdateChannels {
4045 let mut update = proto::UpdateChannels::default();
4046
4047 for channel in channels.channels {
4048 update.channels.push(channel.to_proto());
4049 }
4050
4051 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4052 update.latest_channel_message_ids = channels.latest_channel_messages;
4053
4054 for (channel_id, participants) in channels.channel_participants {
4055 update
4056 .channel_participants
4057 .push(proto::ChannelParticipants {
4058 channel_id: channel_id.to_proto(),
4059 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4060 });
4061 }
4062
4063 for channel in channel_invites {
4064 update.channel_invitations.push(channel.to_proto());
4065 }
4066 for project in channels.hosted_projects {
4067 update.hosted_projects.push(project);
4068 }
4069
4070 update
4071}
4072
4073fn build_initial_contacts_update(
4074 contacts: Vec<db::Contact>,
4075 pool: &ConnectionPool,
4076) -> proto::UpdateContacts {
4077 let mut update = proto::UpdateContacts::default();
4078
4079 for contact in contacts {
4080 match contact {
4081 db::Contact::Accepted { user_id, busy } => {
4082 update.contacts.push(contact_for_user(user_id, busy, &pool));
4083 }
4084 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4085 db::Contact::Incoming { user_id } => {
4086 update
4087 .incoming_requests
4088 .push(proto::IncomingContactRequest {
4089 requester_id: user_id.to_proto(),
4090 })
4091 }
4092 }
4093 }
4094
4095 update
4096}
4097
4098fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4099 proto::Contact {
4100 user_id: user_id.to_proto(),
4101 online: pool.is_user_online(user_id),
4102 busy,
4103 }
4104}
4105
4106fn room_updated(room: &proto::Room, peer: &Peer) {
4107 broadcast(
4108 None,
4109 room.participants
4110 .iter()
4111 .filter_map(|participant| Some(participant.peer_id?.into())),
4112 |peer_id| {
4113 peer.send(
4114 peer_id,
4115 proto::RoomUpdated {
4116 room: Some(room.clone()),
4117 },
4118 )
4119 },
4120 );
4121}
4122
4123fn channel_updated(
4124 channel: &db::channel::Model,
4125 room: &proto::Room,
4126 peer: &Peer,
4127 pool: &ConnectionPool,
4128) {
4129 let participants = room
4130 .participants
4131 .iter()
4132 .map(|p| p.user_id)
4133 .collect::<Vec<_>>();
4134
4135 broadcast(
4136 None,
4137 pool.channel_connection_ids(channel.root_id())
4138 .filter_map(|(channel_id, role)| {
4139 role.can_see_channel(channel.visibility).then(|| channel_id)
4140 }),
4141 |peer_id| {
4142 peer.send(
4143 peer_id,
4144 proto::UpdateChannels {
4145 channel_participants: vec![proto::ChannelParticipants {
4146 channel_id: channel.id.to_proto(),
4147 participant_user_ids: participants.clone(),
4148 }],
4149 ..Default::default()
4150 },
4151 )
4152 },
4153 );
4154}
4155
4156async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4157 let db = session.db().await;
4158
4159 let contacts = db.get_contacts(user_id).await?;
4160 let busy = db.is_user_busy(user_id).await?;
4161
4162 let pool = session.connection_pool().await;
4163 let updated_contact = contact_for_user(user_id, busy, &pool);
4164 for contact in contacts {
4165 if let db::Contact::Accepted {
4166 user_id: contact_user_id,
4167 ..
4168 } = contact
4169 {
4170 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4171 session
4172 .peer
4173 .send(
4174 contact_conn_id,
4175 proto::UpdateContacts {
4176 contacts: vec![updated_contact.clone()],
4177 remove_contacts: Default::default(),
4178 incoming_requests: Default::default(),
4179 remove_incoming_requests: Default::default(),
4180 outgoing_requests: Default::default(),
4181 remove_outgoing_requests: Default::default(),
4182 },
4183 )
4184 .trace_err();
4185 }
4186 }
4187 }
4188 Ok(())
4189}
4190
4191async fn leave_room_for_session(session: &UserSession) -> Result<()> {
4192 let mut contacts_to_update = HashSet::default();
4193
4194 let room_id;
4195 let canceled_calls_to_user_ids;
4196 let live_kit_room;
4197 let delete_live_kit_room;
4198 let room;
4199 let channel;
4200
4201 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
4202 contacts_to_update.insert(session.user_id());
4203
4204 for project in left_room.left_projects.values() {
4205 project_left(project, session);
4206 }
4207
4208 room_id = RoomId::from_proto(left_room.room.id);
4209 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4210 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4211 delete_live_kit_room = left_room.deleted;
4212 room = mem::take(&mut left_room.room);
4213 channel = mem::take(&mut left_room.channel);
4214
4215 room_updated(&room, &session.peer);
4216 } else {
4217 return Ok(());
4218 }
4219
4220 if let Some(channel) = channel {
4221 channel_updated(
4222 &channel,
4223 &room,
4224 &session.peer,
4225 &*session.connection_pool().await,
4226 );
4227 }
4228
4229 {
4230 let pool = session.connection_pool().await;
4231 for canceled_user_id in canceled_calls_to_user_ids {
4232 for connection_id in pool.user_connection_ids(canceled_user_id) {
4233 session
4234 .peer
4235 .send(
4236 connection_id,
4237 proto::CallCanceled {
4238 room_id: room_id.to_proto(),
4239 },
4240 )
4241 .trace_err();
4242 }
4243 contacts_to_update.insert(canceled_user_id);
4244 }
4245 }
4246
4247 for contact_user_id in contacts_to_update {
4248 update_user_contacts(contact_user_id, &session).await?;
4249 }
4250
4251 if let Some(live_kit) = session.live_kit_client.as_ref() {
4252 live_kit
4253 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4254 .await
4255 .trace_err();
4256
4257 if delete_live_kit_room {
4258 live_kit.delete_room(live_kit_room).await.trace_err();
4259 }
4260 }
4261
4262 Ok(())
4263}
4264
4265async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4266 let left_channel_buffers = session
4267 .db()
4268 .await
4269 .leave_channel_buffers(session.connection_id)
4270 .await?;
4271
4272 for left_buffer in left_channel_buffers {
4273 channel_buffer_updated(
4274 session.connection_id,
4275 left_buffer.connections,
4276 &proto::UpdateChannelBufferCollaborators {
4277 channel_id: left_buffer.channel_id.to_proto(),
4278 collaborators: left_buffer.collaborators,
4279 },
4280 &session.peer,
4281 );
4282 }
4283
4284 Ok(())
4285}
4286
4287fn project_left(project: &db::LeftProject, session: &UserSession) {
4288 for connection_id in &project.connection_ids {
4289 if project.host_user_id == Some(session.user_id()) {
4290 session
4291 .peer
4292 .send(
4293 *connection_id,
4294 proto::UnshareProject {
4295 project_id: project.id.to_proto(),
4296 },
4297 )
4298 .trace_err();
4299 } else {
4300 session
4301 .peer
4302 .send(
4303 *connection_id,
4304 proto::RemoveProjectCollaborator {
4305 project_id: project.id.to_proto(),
4306 peer_id: Some(session.connection_id.into()),
4307 },
4308 )
4309 .trace_err();
4310 }
4311 }
4312}
4313
4314pub trait ResultExt {
4315 type Ok;
4316
4317 fn trace_err(self) -> Option<Self::Ok>;
4318}
4319
4320impl<T, E> ResultExt for Result<T, E>
4321where
4322 E: std::fmt::Debug,
4323{
4324 type Ok = T;
4325
4326 fn trace_err(self) -> Option<T> {
4327 match self {
4328 Ok(value) => Some(value),
4329 Err(error) => {
4330 tracing::error!("{:?}", error);
4331 None
4332 }
4333 }
4334 }
4335}