@@ -1030,19 +1030,26 @@ where
})
}
- pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result<proto::Room> {
+ pub async fn decline_call(
+ &self,
+ expected_room_id: Option<RoomId>,
+ user_id: UserId,
+ ) -> Result<proto::Room> {
test_support!(self, {
let mut tx = self.pool.begin().await?;
- sqlx::query(
+ let room_id = sqlx::query_scalar(
"
DELETE FROM room_participants
- WHERE room_id = $1 AND user_id = $2 AND answering_connection_id IS NULL
+ WHERE user_id = $1 AND answering_connection_id IS NULL
+ RETURNING room_id
",
)
- .bind(room_id)
.bind(user_id)
- .execute(&mut tx)
+ .fetch_one(&mut tx)
.await?;
+ if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
+ return Err(anyhow!("declining call on unexpected room"))?;
+ }
self.commit_room_transaction(room_id, tx).await
})
@@ -1050,23 +1057,26 @@ where
pub async fn cancel_call(
&self,
- room_id: RoomId,
+ expected_room_id: Option<RoomId>,
calling_connection_id: ConnectionId,
called_user_id: UserId,
) -> Result<proto::Room> {
test_support!(self, {
let mut tx = self.pool.begin().await?;
- sqlx::query(
+ let room_id = sqlx::query_scalar(
"
DELETE FROM room_participants
- WHERE room_id = $1 AND user_id = $2 AND calling_connection_id = $3 AND answering_connection_id IS NULL
+ WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
+ RETURNING room_id
",
)
- .bind(room_id)
.bind(called_user_id)
.bind(calling_connection_id.0 as i32)
- .execute(&mut tx)
+ .fetch_one(&mut tx)
.await?;
+ if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
+ return Err(anyhow!("canceling call on unexpected room"))?;
+ }
self.commit_room_transaction(room_id, tx).await
})
@@ -430,9 +430,29 @@ impl Server {
user_id: UserId,
) -> Result<()> {
self.peer.disconnect(connection_id);
- self.store().await.remove_connection(connection_id)?;
+ let decline_calls = {
+ let mut store = self.store().await;
+ store.remove_connection(connection_id)?;
+ let mut connections = store.connection_ids_for_user(user_id);
+ connections.next().is_none()
+ };
+
self.leave_room_for_connection(connection_id, user_id)
- .await?;
+ .await
+ .trace_err();
+ if decline_calls {
+ if let Some(room) = self
+ .app_state
+ .db
+ .decline_call(None, user_id)
+ .await
+ .trace_err()
+ {
+ self.room_updated(&room);
+ }
+ }
+
+ self.update_user_contacts(user_id).await?;
Ok(())
}
@@ -761,11 +781,10 @@ impl Server {
) -> Result<()> {
let called_user_id = UserId::from_proto(request.payload.called_user_id);
let room_id = RoomId::from_proto(request.payload.room_id);
-
let room = self
.app_state
.db
- .cancel_call(room_id, request.sender_connection_id, called_user_id)
+ .cancel_call(Some(room_id), request.sender_connection_id, called_user_id)
.await?;
for connection_id in self.store().await.connection_ids_for_user(called_user_id) {
self.peer
@@ -780,13 +799,11 @@ impl Server {
}
async fn decline_call(self: Arc<Server>, message: Message<proto::DeclineCall>) -> Result<()> {
+ let room_id = RoomId::from_proto(message.payload.room_id);
let room = self
.app_state
.db
- .decline_call(
- RoomId::from_proto(message.payload.room_id),
- message.sender_user_id,
- )
+ .decline_call(Some(room_id), message.sender_user_id)
.await?;
for connection_id in self
.store()