tokio_xmpp: implement IQ tracking

Jonas SchΓ€fer created

Change summary

tokio-xmpp/Cargo.toml                |   2 
tokio-xmpp/ChangeLog                 |   3 
tokio-xmpp/examples/contact_addr.rs  |  67 +++--
tokio-xmpp/src/client/iq.rs          | 305 ++++++++++++++++++++++++++++++
tokio-xmpp/src/client/mod.rs         |  35 +++
tokio-xmpp/src/client/stream.rs      |  11 
tokio-xmpp/src/event.rs              |   2 
tokio-xmpp/src/lib.rs                |   2 
tokio-xmpp/src/stanzastream/queue.rs |   4 
9 files changed, 399 insertions(+), 32 deletions(-)

Detailed changes

tokio-xmpp/Cargo.toml πŸ”—

@@ -16,7 +16,7 @@ bytes = "1"
 futures = "0.3"
 log = "0.4"
 tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] }
-tokio-stream = { version = "0.1", features = [] }
+tokio-stream = { version = "0.1", features = ["sync"] }
 webpki-roots = { version = "0.26", optional = true }
 rustls-native-certs = { version = "0.7", optional = true }
 rxml = { version = "0.12.0", features = ["compact_str"] }

tokio-xmpp/ChangeLog πŸ”—

@@ -25,6 +25,9 @@ XXXX-YY-ZZ RELEASER <admin@example.com>
       - `Component` is now gated behind `insecure-tcp` feature flag
       - `XMPPStream` and `XmppCodec` were removed in favour of the newly
         implemented `tokio_xmpp::xmlstream module.
+    * Added:
+      - Support for sending IQ requests while tracking their responses in a
+        Future.
     * Changes:
       - On Linux, once the TLS session is established, we can delegate the
         actual encryption and decryption to the kernel, which in turn can

tokio-xmpp/examples/contact_addr.rs πŸ”—

@@ -2,10 +2,9 @@ use futures::stream::StreamExt;
 use std::env::args;
 use std::process::exit;
 use std::str::FromStr;
-use tokio_xmpp::{Client, Stanza};
+use tokio_xmpp::{Client, IqRequest, IqResponse};
 use xmpp_parsers::{
     disco::{DiscoInfoQuery, DiscoInfoResult},
-    iq::{Iq, IqType},
     jid::{BareJid, Jid},
     ns,
     server_info::ServerInfo,
@@ -22,45 +21,59 @@ async fn main() {
     }
     let jid = BareJid::from_str(&args[1]).expect(&format!("Invalid JID: {}", &args[1]));
     let password = args[2].clone();
-    let target = &args[3];
+    let target = Jid::from_str(&args[3]).expect(&format!("Invalid JID: {}", &args[3]));
 
     // Client instance
     let mut client = Client::new(jid, password);
 
-    // Main loop, processes events
-    while let Some(event) = client.next().await {
-        if event.is_online() {
-            println!("Online!");
+    let token = client
+        .send_iq(
+            Some(target),
+            IqRequest::Get(DiscoInfoQuery { node: None }.into()),
+        )
+        .await;
+    tokio::pin!(token);
 
-            let target_jid: Jid = target.clone().parse().unwrap();
-            let iq = make_disco_iq(target_jid);
-            println!("Sending disco#info request to {}", target.clone());
-            println!(">> {:?}", iq);
-            client.send_stanza(iq.into()).await.unwrap();
-        } else if let Some(Stanza::Iq(iq)) = event.into_stanza() {
-            if let IqType::Result(Some(payload)) = iq.payload {
-                if payload.is("query", ns::DISCO_INFO) {
-                    if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
-                        for ext in disco_info.extensions {
-                            if let Ok(server_info) = ServerInfo::try_from(ext) {
-                                print_server_info(server_info);
+    // Main loop, processes events
+    loop {
+        tokio::select! {
+            response = &mut token => match response {
+                Ok(IqResponse::Result(Some(payload))) => {
+                    if payload.is("query", ns::DISCO_INFO) {
+                        if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
+                            for ext in disco_info.extensions {
+                                if let Ok(server_info) = ServerInfo::try_from(ext) {
+                                    print_server_info(server_info);
+                                }
                             }
                         }
                     }
+                    break;
                 }
-                break;
-            }
+                Ok(IqResponse::Result(None)) => {
+                    panic!("disco#info response misses payload!");
+                }
+                Ok(IqResponse::Error(err)) => {
+                    panic!("disco#info response is an error: {:?}", err);
+                }
+                Err(err) => {
+                    panic!("disco#info request failed to send: {}", err);
+                }
+            },
+            event = client.next() => {
+                let Some(event) = event else {
+                    println!("Client terminated");
+                    break;
+                };
+                if event.is_online() {
+                    println!("Online!");
+                }
+            },
         }
     }
     client.send_end().await.expect("Stream shutdown unclean");
 }
 
-fn make_disco_iq(target: Jid) -> Iq {
-    Iq::from_get("disco", DiscoInfoQuery { node: None })
-        .with_id(String::from("contact"))
-        .with_to(target)
-}
-
 fn convert_field(field: Vec<String>) -> String {
     field
         .iter()

tokio-xmpp/src/client/iq.rs πŸ”—

@@ -0,0 +1,305 @@
+// Copyright (c) 2025 Jonas SchΓ€fer <jonas@zombofant.net>
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+use alloc::collections::BTreeMap;
+use alloc::sync::{Arc, Weak};
+use core::error::Error;
+use core::fmt;
+use core::future::Future;
+use core::ops::ControlFlow;
+use core::pin::Pin;
+use core::task::{ready, Context, Poll};
+use std::io;
+use std::sync::Mutex;
+
+use futures::Stream;
+use tokio::sync::oneshot;
+
+use xmpp_parsers::{
+    iq::{Iq, IqType},
+    stanza_error::StanzaError,
+};
+
+use crate::{
+    event::make_id,
+    jid::Jid,
+    minidom::Element,
+    stanzastream::{StanzaState, StanzaToken},
+};
+
+/// An IQ request payload
+pub enum IqRequest {
+    /// Payload for a `type="get"` request
+    Get(Element),
+
+    /// Payload for a `type="set"` request
+    Set(Element),
+}
+
+impl From<IqRequest> for IqType {
+    fn from(other: IqRequest) -> IqType {
+        match other {
+            IqRequest::Get(v) => Self::Get(v),
+            IqRequest::Set(v) => Self::Set(v),
+        }
+    }
+}
+
+/// An IQ response payload
+pub enum IqResponse {
+    /// Payload for a `type="result"` response.
+    Result(Option<Element>),
+
+    /// Payload for a `type="error"` response.
+    Error(StanzaError),
+}
+
+impl From<IqResponse> for IqType {
+    fn from(other: IqResponse) -> IqType {
+        match other {
+            IqResponse::Result(v) => Self::Result(v),
+            IqResponse::Error(v) => Self::Error(v),
+        }
+    }
+}
+
+/// Error enumeration for Iq sending failures
+#[derive(Debug)]
+pub enum IqFailure {
+    /// Internal error inside tokio_xmpp which caused the stream worker to
+    /// drop the token before the response was received.
+    ///
+    /// Most likely, this means that the stream has died with a panic.
+    LostWorker,
+
+    /// The IQ failed to send because of an I/O or serialisation error.
+    SendError(io::Error),
+}
+
+impl fmt::Display for IqFailure {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::LostWorker => {
+                f.write_str("disconnected from internal connection worker while sending IQ")
+            }
+            Self::SendError(e) => write!(f, "send error: {e}"),
+        }
+    }
+}
+
+impl Error for IqFailure {
+    fn source(&self) -> Option<&(dyn Error + 'static)> {
+        match self {
+            Self::SendError(ref e) => Some(e),
+            Self::LostWorker => None,
+        }
+    }
+}
+
+type IqKey = (Option<Jid>, String);
+type IqMap = BTreeMap<IqKey, IqResponseSink>;
+
+struct IqMapEntryHandle {
+    key: IqKey,
+    map: Weak<Mutex<IqMap>>,
+}
+
+impl Drop for IqMapEntryHandle {
+    fn drop(&mut self) {
+        let Some(map) = self.map.upgrade() else {
+            return;
+        };
+        let Some(mut map) = map.lock().ok() else {
+            return;
+        };
+        map.remove(&self.key);
+    }
+}
+
+pin_project_lite::pin_project! {
+    /// Handle for awaiting an IQ response.
+    ///
+    /// The `IqResponseToken` can be awaited and will generate a result once
+    /// the Iq response has been received. Note that an `Ok(_)` result does
+    /// **not** imply a successful execution of the remote command: It may
+    /// contain a [`IqResponse::Error`] variant.
+    ///
+    /// Note that there are no internal timeouts for Iq responses: If a reply
+    /// never arrives, the [`IqResponseToken`] future will never complete.
+    /// Most of the time, you should combine that token with something like
+    /// [`tokio::time::timeout`].
+    ///
+    /// Dropping (cancelling) an `IqResponseToken` removes the internal
+    /// bookkeeping required for tracking the response.
+    pub struct IqResponseToken {
+        entry: Option<IqMapEntryHandle>,
+        #[pin]
+        stanza_token: Option<tokio_stream::wrappers::WatchStream<StanzaState>>,
+        #[pin]
+        inner: oneshot::Receiver<Result<IqResponse, IqFailure>>,
+    }
+}
+
+impl IqResponseToken {
+    /// Tie a stanza token to this IQ response token.
+    ///
+    /// The stanza token should point at the IQ **request**, the response of
+    /// which this response token awaits.
+    ///
+    /// Awaiting the response token will then handle error states in the
+    /// stanza token and return IqFailure as appropriate.
+    pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) {
+        assert!(self.stanza_token.is_none());
+        self.stanza_token = Some(token.into_stream());
+    }
+}
+
+impl Future for IqResponseToken {
+    type Output = Result<IqResponse, IqFailure>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.project();
+        match this.inner.poll(cx) {
+            Poll::Ready(Ok(v)) => {
+                // Drop the map entry handle to release some memory.
+                this.entry.take();
+                return Poll::Ready(v);
+            }
+            Poll::Ready(Err(_)) => {
+                log::warn!("IqResponseToken oneshot::Receiver returned receive error!");
+                // Drop the map entry handle to release some memory.
+                this.entry.take();
+                return Poll::Ready(Err(IqFailure::LostWorker));
+            }
+            Poll::Pending => (),
+        };
+
+        loop {
+            match this.stanza_token.as_mut().as_pin_mut() {
+                // We have a stanza token to look at, so we check its state.
+                Some(stream) => match ready!(stream.poll_next(cx)) {
+                    // Still in the queue.
+                    Some(StanzaState::Queued) => (),
+
+                    Some(StanzaState::Dropped) | None => {
+                        log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission.");
+                        // Drop the map entry handle to release some memory.
+                        this.entry.take();
+                        // Lost stanza stream: cannot ever get a reply.
+                        return Poll::Ready(Err(IqFailure::LostWorker));
+                    }
+
+                    Some(StanzaState::Failed { error }) => {
+                        // Drop the map entry handle to release some memory.
+                        this.entry.take();
+                        // Send error: cannot ever get a reply.
+                        return Poll::Ready(Err(IqFailure::SendError(error.into_io_error())));
+                    }
+
+                    Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => {
+                        // Sent successfully, stop polling the stream: We do
+                        // not care what happens after successful sending,
+                        // the next step we expect is that this.inner
+                        // completes.
+                        *this.stanza_token = None;
+                        return Poll::Pending;
+                    }
+                },
+
+                // No StanzaToken to poll, so we return Poll::Pending and hope
+                // that we will get a response through this.inner eventually..
+                None => return Poll::Pending,
+            }
+        }
+    }
+}
+
+struct IqResponseSink {
+    inner: oneshot::Sender<Result<IqResponse, IqFailure>>,
+}
+
+impl IqResponseSink {
+    fn complete(self, resp: IqResponse) {
+        let _: Result<_, _> = self.inner.send(Ok(resp));
+    }
+}
+
+/// Utility struct to track IQ responses.
+pub struct IqResponseTracker {
+    map: Arc<Mutex<IqMap>>,
+}
+
+impl IqResponseTracker {
+    /// Create a new empty response tracker.
+    pub fn new() -> Self {
+        Self {
+            map: Arc::new(Mutex::new(IqMap::new())),
+        }
+    }
+
+    /// Attempt to handle an IQ stanza as IQ response.
+    ///
+    /// Returns the IQ stanza unharmed if it is not an IQ response matching
+    /// any request which is still being tracked.
+    pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> {
+        let payload = match iq.payload {
+            IqType::Error(error) => IqResponse::Error(error),
+            IqType::Result(result) => IqResponse::Result(result),
+            _ => return ControlFlow::Continue(iq),
+        };
+        let key = (iq.from, iq.id);
+        let mut map = self.map.lock().unwrap();
+        match map.remove(&key) {
+            None => {
+                log::trace!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1);
+                ControlFlow::Continue(Iq {
+                    from: key.0,
+                    id: key.1,
+                    to: iq.to,
+                    payload: payload.into(),
+                })
+            }
+            Some(sink) => {
+                sink.complete(payload);
+                ControlFlow::Break(())
+            }
+        }
+    }
+
+    /// Allocate a new IQ response tracking handle.
+    ///
+    /// This modifies the IQ to assign a unique ID.
+    pub fn allocate_iq_handle(
+        &self,
+        from: Option<Jid>,
+        to: Option<Jid>,
+        req: IqRequest,
+    ) -> (Iq, IqResponseToken) {
+        let key = (to, make_id());
+        let mut map = self.map.lock().unwrap();
+        let (tx, rx) = oneshot::channel();
+        let sink = IqResponseSink { inner: tx };
+        assert!(map.get(&key).is_none());
+        let token = IqResponseToken {
+            entry: Some(IqMapEntryHandle {
+                key: key.clone(),
+                map: Arc::downgrade(&self.map),
+            }),
+            stanza_token: None,
+            inner: rx,
+        };
+        map.insert(key.clone(), sink);
+        (
+            Iq {
+                from,
+                to: key.0,
+                id: key.1,
+                payload: req.into(),
+            },
+            token,
+        )
+    }
+}

tokio-xmpp/src/client/mod.rs πŸ”—

@@ -23,9 +23,12 @@ use crate::connect::StartTlsServerConnector;
 #[cfg(feature = "insecure-tcp")]
 use crate::connect::TcpServerConnector;
 
+mod iq;
 pub(crate) mod login;
 mod stream;
 
+pub use iq::{IqFailure, IqRequest, IqResponse, IqResponseToken};
+
 /// XMPP client connection and state
 ///
 /// This implements the `futures` crate's [`Stream`](#impl-Stream) to receive
@@ -37,6 +40,7 @@ pub struct Client {
     stream: StanzaStream,
     bound_jid: Option<Jid>,
     features: Option<StreamFeatures>,
+    iq_response_tracker: iq::IqResponseTracker,
 }
 
 impl Client {
@@ -59,6 +63,9 @@ impl Client {
     /// method can be called with [`StanzaStage::Acked`], but that stage will
     /// only ever be reached if the server supports XEP-0198 and it has been
     /// negotiated successfully (this may change in the future).
+    ///
+    /// For sending Iq request stanzas, it is recommended to use
+    /// [`send_iq`][`Self::send_iq`], which allows awaiting the response.
     pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<StanzaToken, io::Error> {
         stanza.ensure_id();
         let mut token = self.stream.send(Box::new(stanza)).await;
@@ -75,6 +82,33 @@ impl Client {
         }
     }
 
+    /// Send an IQ request and return a token to retrieve the response.
+    ///
+    /// This coroutine method will complete once the Iq has been sent to the
+    /// server. The returned `IqResponseToken` can be used to await the
+    /// response. See also the documentation of [`IqResponseToken`] for more
+    /// information on the behaviour of these tokens.
+    ///
+    /// **Important**: Even though IQ responses are delivered through the
+    /// returned token (and never through the `Stream`), the
+    /// [`Stream`][`futures::Stream`]
+    /// implementation of the [`Client`] **must be polled** to make progress
+    /// on the stream and to process incoming stanzas and thus to deliver them
+    /// to the returned token.
+    ///
+    /// **Note**: If an IQ response arrives after the `token` has been
+    /// dropped (e.g. due to a timeout), it will be delivered through the
+    /// `Stream` like any other stanza.
+    pub async fn send_iq(&mut self, to: Option<Jid>, req: IqRequest) -> IqResponseToken {
+        let (iq, mut token) = self.iq_response_tracker.allocate_iq_handle(
+            // from is always None for a client
+            None, to, req,
+        );
+        let stanza_token = self.stream.send(Box::new(iq.into())).await;
+        token.set_stanza_token(stanza_token);
+        token
+    }
+
     /// Get the stream features (`<stream:features/>`) of the underlying
     /// stream.
     ///
@@ -153,6 +187,7 @@ impl Client {
             stream: StanzaStream::new_c2s(connector, jid.into(), password.into(), timeouts, 16),
             bound_jid: None,
             features: None,
+            iq_response_tracker: iq::IqResponseTracker::new(),
         }
     }
 }

tokio-xmpp/src/client/stream.rs πŸ”—

@@ -4,13 +4,14 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this
 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
+use core::ops::ControlFlow;
 use core::{pin::Pin, task::Context};
 use futures::{ready, task::Poll, Stream};
 
 use crate::{
     client::Client,
     stanzastream::{Event as StanzaStreamEvent, StreamEvent},
-    Event,
+    Event, Stanza,
 };
 
 /// Incoming XMPP events
@@ -34,7 +35,13 @@ impl Stream for Client {
         loop {
             return Poll::Ready(match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
                 None => None,
-                Some(StanzaStreamEvent::Stanza(st)) => Some(Event::Stanza(st)),
+                Some(StanzaStreamEvent::Stanza(st)) => match st {
+                    Stanza::Iq(iq) => match self.iq_response_tracker.handle_iq(iq) {
+                        ControlFlow::Break(()) => continue,
+                        ControlFlow::Continue(iq) => Some(Event::Stanza(Stanza::Iq(iq))),
+                    },
+                    other => Some(Event::Stanza(other)),
+                },
                 Some(StanzaStreamEvent::Stream(StreamEvent::Reset {
                     bound_jid,
                     features,

tokio-xmpp/src/event.rs πŸ”—

@@ -16,7 +16,7 @@ use xso::{AsXml, FromXml};
 use crate::xmlstream::XmppStreamElement;
 use crate::Error;
 
-fn make_id() -> String {
+pub(crate) fn make_id() -> String {
     let id: u64 = thread_rng().gen();
     format!("{}", id)
 }

tokio-xmpp/src/lib.rs πŸ”—

@@ -64,7 +64,7 @@ pub mod xmlstream;
 #[doc(inline)]
 /// Generic tokio_xmpp Error
 pub use crate::error::Error;
-pub use client::Client;
+pub use client::{Client, IqFailure, IqRequest, IqResponse, IqResponseToken};
 #[cfg(feature = "insecure-tcp")]
 pub use component::Component;
 pub use event::{Event, Stanza};

tokio-xmpp/src/stanzastream/queue.rs πŸ”—

@@ -184,6 +184,10 @@ impl StanzaToken {
             .ok()
     }
 
+    pub(crate) fn into_stream(self) -> tokio_stream::wrappers::WatchStream<StanzaState> {
+        tokio_stream::wrappers::WatchStream::new(self.inner)
+    }
+
     /// Read the current transmission state.
     pub fn state(&self) -> StanzaState {
         self.inner.borrow().clone()