@@ -1613,9 +1613,18 @@ impl ChannelClient {
pub fn request<T: RequestMessage>(
&self,
payload: T,
+ ) -> impl 'static + Future<Output = Result<T::Response>> {
+ self.request_internal(payload, true)
+ }
+
+ fn request_internal<T: RequestMessage>(
+ &self,
+ payload: T,
+ use_buffer: bool,
) -> impl 'static + Future<Output = Result<T::Response>> {
log::debug!("ssh request start. name:{}", T::NAME);
- let response = self.request_dynamic(payload.into_envelope(0, None, None), T::NAME);
+ let response =
+ self.request_dynamic(payload.into_envelope(0, None, None), T::NAME, use_buffer);
async move {
let response = response.await?;
log::debug!("ssh request finish. name:{}", T::NAME);
@@ -1627,7 +1636,9 @@ impl ChannelClient {
pub async fn resync(&self, timeout: Duration) -> Result<()> {
smol::future::or(
async {
- self.request(proto::FlushBufferedMessages {}).await?;
+ self.request_internal(proto::FlushBufferedMessages {}, false)
+ .await?;
+
for envelope in self.buffer.lock().iter() {
self.outgoing_tx
.lock()
@@ -1663,10 +1674,11 @@ impl ChannelClient {
self.send_dynamic(payload.into_envelope(0, None, None))
}
- pub fn request_dynamic(
+ fn request_dynamic(
&self,
mut envelope: proto::Envelope,
type_name: &'static str,
+ use_buffer: bool,
) -> impl 'static + Future<Output = Result<proto::Envelope>> {
envelope.id = self.next_message_id.fetch_add(1, SeqCst);
let (tx, rx) = oneshot::channel();
@@ -1674,7 +1686,11 @@ impl ChannelClient {
response_channels_lock.insert(MessageId(envelope.id), tx);
drop(response_channels_lock);
- let result = self.send_buffered(envelope);
+ let result = if use_buffer {
+ self.send_buffered(envelope)
+ } else {
+ self.send_unbuffered(envelope)
+ };
async move {
if let Err(error) = &result {
log::error!("failed to send message: {}", error);
@@ -1694,7 +1710,7 @@ impl ChannelClient {
self.send_buffered(envelope)
}
- pub fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
+ fn send_buffered(&self, mut envelope: proto::Envelope) -> Result<()> {
envelope.ack_id = Some(self.max_received.load(SeqCst));
self.buffer.lock().push_back(envelope.clone());
// ignore errors on send (happen while we're reconnecting)
@@ -1702,6 +1718,12 @@ impl ChannelClient {
self.outgoing_tx.lock().unbounded_send(envelope).ok();
Ok(())
}
+
+ fn send_unbuffered(&self, mut envelope: proto::Envelope) -> Result<()> {
+ envelope.ack_id = Some(self.max_received.load(SeqCst));
+ self.outgoing_tx.lock().unbounded_send(envelope).ok();
+ Ok(())
+ }
}
impl ProtoClient for ChannelClient {
@@ -1710,7 +1732,7 @@ impl ProtoClient for ChannelClient {
envelope: proto::Envelope,
request_type: &'static str,
) -> BoxFuture<'static, Result<proto::Envelope>> {
- self.request_dynamic(envelope, request_type).boxed()
+ self.request_dynamic(envelope, request_type, true).boxed()
}
fn send(&self, envelope: proto::Envelope, _message_type: &'static str) -> Result<()> {