From 8a4d548d9aa2827cd226e6ee898020fa42bb701a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Sch=C3=A4fer?= Date: Sun, 26 Jan 2025 11:10:44 +0100 Subject: [PATCH] tokio_xmpp: implement IQ tracking --- tokio-xmpp/Cargo.toml | 2 +- tokio-xmpp/ChangeLog | 3 + tokio-xmpp/examples/contact_addr.rs | 67 +++--- tokio-xmpp/src/client/iq.rs | 305 +++++++++++++++++++++++++++ tokio-xmpp/src/client/mod.rs | 35 +++ tokio-xmpp/src/client/stream.rs | 11 +- tokio-xmpp/src/event.rs | 2 +- tokio-xmpp/src/lib.rs | 2 +- tokio-xmpp/src/stanzastream/queue.rs | 4 + 9 files changed, 399 insertions(+), 32 deletions(-) create mode 100644 tokio-xmpp/src/client/iq.rs diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 673102ca0286e0db2cc96a021b4cfb20e04dffba..262a2b30d7d92356727d65ee2d049fd8db0809cc 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -16,7 +16,7 @@ bytes = "1" futures = "0.3" log = "0.4" tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] } -tokio-stream = { version = "0.1", features = [] } +tokio-stream = { version = "0.1", features = ["sync"] } webpki-roots = { version = "0.26", optional = true } rustls-native-certs = { version = "0.7", optional = true } rxml = { version = "0.12.0", features = ["compact_str"] } diff --git a/tokio-xmpp/ChangeLog b/tokio-xmpp/ChangeLog index cbd09448f072b39200e59894ed10cc8861ee829a..cbf166fe83ca50e1bea8688d9a392dfead5ed101 100644 --- a/tokio-xmpp/ChangeLog +++ b/tokio-xmpp/ChangeLog @@ -25,6 +25,9 @@ XXXX-YY-ZZ RELEASER - `Component` is now gated behind `insecure-tcp` feature flag - `XMPPStream` and `XmppCodec` were removed in favour of the newly implemented `tokio_xmpp::xmlstream module. + * Added: + - Support for sending IQ requests while tracking their responses in a + Future. * Changes: - On Linux, once the TLS session is established, we can delegate the actual encryption and decryption to the kernel, which in turn can diff --git a/tokio-xmpp/examples/contact_addr.rs b/tokio-xmpp/examples/contact_addr.rs index 343b90d9408b7823226e8add3a4529a480d88a09..584250f17db26f0f501b6076e3a5979510d4c37e 100644 --- a/tokio-xmpp/examples/contact_addr.rs +++ b/tokio-xmpp/examples/contact_addr.rs @@ -2,10 +2,9 @@ use futures::stream::StreamExt; use std::env::args; use std::process::exit; use std::str::FromStr; -use tokio_xmpp::{Client, Stanza}; +use tokio_xmpp::{Client, IqRequest, IqResponse}; use xmpp_parsers::{ disco::{DiscoInfoQuery, DiscoInfoResult}, - iq::{Iq, IqType}, jid::{BareJid, Jid}, ns, server_info::ServerInfo, @@ -22,45 +21,59 @@ async fn main() { } let jid = BareJid::from_str(&args[1]).expect(&format!("Invalid JID: {}", &args[1])); let password = args[2].clone(); - let target = &args[3]; + let target = Jid::from_str(&args[3]).expect(&format!("Invalid JID: {}", &args[3])); // Client instance let mut client = Client::new(jid, password); - // Main loop, processes events - while let Some(event) = client.next().await { - if event.is_online() { - println!("Online!"); + let token = client + .send_iq( + Some(target), + IqRequest::Get(DiscoInfoQuery { node: None }.into()), + ) + .await; + tokio::pin!(token); - let target_jid: Jid = target.clone().parse().unwrap(); - let iq = make_disco_iq(target_jid); - println!("Sending disco#info request to {}", target.clone()); - println!(">> {:?}", iq); - client.send_stanza(iq.into()).await.unwrap(); - } else if let Some(Stanza::Iq(iq)) = event.into_stanza() { - if let IqType::Result(Some(payload)) = iq.payload { - if payload.is("query", ns::DISCO_INFO) { - if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { - for ext in disco_info.extensions { - if let Ok(server_info) = ServerInfo::try_from(ext) { - print_server_info(server_info); + // Main loop, processes events + loop { + tokio::select! { + response = &mut token => match response { + Ok(IqResponse::Result(Some(payload))) => { + if payload.is("query", ns::DISCO_INFO) { + if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { + for ext in disco_info.extensions { + if let Ok(server_info) = ServerInfo::try_from(ext) { + print_server_info(server_info); + } } } } + break; } - break; - } + Ok(IqResponse::Result(None)) => { + panic!("disco#info response misses payload!"); + } + Ok(IqResponse::Error(err)) => { + panic!("disco#info response is an error: {:?}", err); + } + Err(err) => { + panic!("disco#info request failed to send: {}", err); + } + }, + event = client.next() => { + let Some(event) = event else { + println!("Client terminated"); + break; + }; + if event.is_online() { + println!("Online!"); + } + }, } } client.send_end().await.expect("Stream shutdown unclean"); } -fn make_disco_iq(target: Jid) -> Iq { - Iq::from_get("disco", DiscoInfoQuery { node: None }) - .with_id(String::from("contact")) - .with_to(target) -} - fn convert_field(field: Vec) -> String { field .iter() diff --git a/tokio-xmpp/src/client/iq.rs b/tokio-xmpp/src/client/iq.rs new file mode 100644 index 0000000000000000000000000000000000000000..bfa7cbeaf55204844d6865d22c164524227506c1 --- /dev/null +++ b/tokio-xmpp/src/client/iq.rs @@ -0,0 +1,305 @@ +// Copyright (c) 2025 Jonas Schäfer +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use alloc::collections::BTreeMap; +use alloc::sync::{Arc, Weak}; +use core::error::Error; +use core::fmt; +use core::future::Future; +use core::ops::ControlFlow; +use core::pin::Pin; +use core::task::{ready, Context, Poll}; +use std::io; +use std::sync::Mutex; + +use futures::Stream; +use tokio::sync::oneshot; + +use xmpp_parsers::{ + iq::{Iq, IqType}, + stanza_error::StanzaError, +}; + +use crate::{ + event::make_id, + jid::Jid, + minidom::Element, + stanzastream::{StanzaState, StanzaToken}, +}; + +/// An IQ request payload +pub enum IqRequest { + /// Payload for a `type="get"` request + Get(Element), + + /// Payload for a `type="set"` request + Set(Element), +} + +impl From for IqType { + fn from(other: IqRequest) -> IqType { + match other { + IqRequest::Get(v) => Self::Get(v), + IqRequest::Set(v) => Self::Set(v), + } + } +} + +/// An IQ response payload +pub enum IqResponse { + /// Payload for a `type="result"` response. + Result(Option), + + /// Payload for a `type="error"` response. + Error(StanzaError), +} + +impl From for IqType { + fn from(other: IqResponse) -> IqType { + match other { + IqResponse::Result(v) => Self::Result(v), + IqResponse::Error(v) => Self::Error(v), + } + } +} + +/// Error enumeration for Iq sending failures +#[derive(Debug)] +pub enum IqFailure { + /// Internal error inside tokio_xmpp which caused the stream worker to + /// drop the token before the response was received. + /// + /// Most likely, this means that the stream has died with a panic. + LostWorker, + + /// The IQ failed to send because of an I/O or serialisation error. + SendError(io::Error), +} + +impl fmt::Display for IqFailure { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::LostWorker => { + f.write_str("disconnected from internal connection worker while sending IQ") + } + Self::SendError(e) => write!(f, "send error: {e}"), + } + } +} + +impl Error for IqFailure { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::SendError(ref e) => Some(e), + Self::LostWorker => None, + } + } +} + +type IqKey = (Option, String); +type IqMap = BTreeMap; + +struct IqMapEntryHandle { + key: IqKey, + map: Weak>, +} + +impl Drop for IqMapEntryHandle { + fn drop(&mut self) { + let Some(map) = self.map.upgrade() else { + return; + }; + let Some(mut map) = map.lock().ok() else { + return; + }; + map.remove(&self.key); + } +} + +pin_project_lite::pin_project! { + /// Handle for awaiting an IQ response. + /// + /// The `IqResponseToken` can be awaited and will generate a result once + /// the Iq response has been received. Note that an `Ok(_)` result does + /// **not** imply a successful execution of the remote command: It may + /// contain a [`IqResponse::Error`] variant. + /// + /// Note that there are no internal timeouts for Iq responses: If a reply + /// never arrives, the [`IqResponseToken`] future will never complete. + /// Most of the time, you should combine that token with something like + /// [`tokio::time::timeout`]. + /// + /// Dropping (cancelling) an `IqResponseToken` removes the internal + /// bookkeeping required for tracking the response. + pub struct IqResponseToken { + entry: Option, + #[pin] + stanza_token: Option>, + #[pin] + inner: oneshot::Receiver>, + } +} + +impl IqResponseToken { + /// Tie a stanza token to this IQ response token. + /// + /// The stanza token should point at the IQ **request**, the response of + /// which this response token awaits. + /// + /// Awaiting the response token will then handle error states in the + /// stanza token and return IqFailure as appropriate. + pub(crate) fn set_stanza_token(&mut self, token: StanzaToken) { + assert!(self.stanza_token.is_none()); + self.stanza_token = Some(token.into_stream()); + } +} + +impl Future for IqResponseToken { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + match this.inner.poll(cx) { + Poll::Ready(Ok(v)) => { + // Drop the map entry handle to release some memory. + this.entry.take(); + return Poll::Ready(v); + } + Poll::Ready(Err(_)) => { + log::warn!("IqResponseToken oneshot::Receiver returned receive error!"); + // Drop the map entry handle to release some memory. + this.entry.take(); + return Poll::Ready(Err(IqFailure::LostWorker)); + } + Poll::Pending => (), + }; + + loop { + match this.stanza_token.as_mut().as_pin_mut() { + // We have a stanza token to look at, so we check its state. + Some(stream) => match ready!(stream.poll_next(cx)) { + // Still in the queue. + Some(StanzaState::Queued) => (), + + Some(StanzaState::Dropped) | None => { + log::warn!("StanzaToken associated with IqResponseToken signalled that the Stanza was dropped before transmission."); + // Drop the map entry handle to release some memory. + this.entry.take(); + // Lost stanza stream: cannot ever get a reply. + return Poll::Ready(Err(IqFailure::LostWorker)); + } + + Some(StanzaState::Failed { error }) => { + // Drop the map entry handle to release some memory. + this.entry.take(); + // Send error: cannot ever get a reply. + return Poll::Ready(Err(IqFailure::SendError(error.into_io_error()))); + } + + Some(StanzaState::Sent { .. }) | Some(StanzaState::Acked { .. }) => { + // Sent successfully, stop polling the stream: We do + // not care what happens after successful sending, + // the next step we expect is that this.inner + // completes. + *this.stanza_token = None; + return Poll::Pending; + } + }, + + // No StanzaToken to poll, so we return Poll::Pending and hope + // that we will get a response through this.inner eventually.. + None => return Poll::Pending, + } + } + } +} + +struct IqResponseSink { + inner: oneshot::Sender>, +} + +impl IqResponseSink { + fn complete(self, resp: IqResponse) { + let _: Result<_, _> = self.inner.send(Ok(resp)); + } +} + +/// Utility struct to track IQ responses. +pub struct IqResponseTracker { + map: Arc>, +} + +impl IqResponseTracker { + /// Create a new empty response tracker. + pub fn new() -> Self { + Self { + map: Arc::new(Mutex::new(IqMap::new())), + } + } + + /// Attempt to handle an IQ stanza as IQ response. + /// + /// Returns the IQ stanza unharmed if it is not an IQ response matching + /// any request which is still being tracked. + pub fn handle_iq(&self, iq: Iq) -> ControlFlow<(), Iq> { + let payload = match iq.payload { + IqType::Error(error) => IqResponse::Error(error), + IqType::Result(result) => IqResponse::Result(result), + _ => return ControlFlow::Continue(iq), + }; + let key = (iq.from, iq.id); + let mut map = self.map.lock().unwrap(); + match map.remove(&key) { + None => { + log::trace!("not handling IQ response from {:?} with id {:?}: no active tracker for this tuple", key.0, key.1); + ControlFlow::Continue(Iq { + from: key.0, + id: key.1, + to: iq.to, + payload: payload.into(), + }) + } + Some(sink) => { + sink.complete(payload); + ControlFlow::Break(()) + } + } + } + + /// Allocate a new IQ response tracking handle. + /// + /// This modifies the IQ to assign a unique ID. + pub fn allocate_iq_handle( + &self, + from: Option, + to: Option, + req: IqRequest, + ) -> (Iq, IqResponseToken) { + let key = (to, make_id()); + let mut map = self.map.lock().unwrap(); + let (tx, rx) = oneshot::channel(); + let sink = IqResponseSink { inner: tx }; + assert!(map.get(&key).is_none()); + let token = IqResponseToken { + entry: Some(IqMapEntryHandle { + key: key.clone(), + map: Arc::downgrade(&self.map), + }), + stanza_token: None, + inner: rx, + }; + map.insert(key.clone(), sink); + ( + Iq { + from, + to: key.0, + id: key.1, + payload: req.into(), + }, + token, + ) + } +} diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 1c31cbdaf9ca3573e4082f759ef8dbecc94acf66..88b6e3b42a8a8d63fb26068561988ed86ccf21d8 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -23,9 +23,12 @@ use crate::connect::StartTlsServerConnector; #[cfg(feature = "insecure-tcp")] use crate::connect::TcpServerConnector; +mod iq; pub(crate) mod login; mod stream; +pub use iq::{IqFailure, IqRequest, IqResponse, IqResponseToken}; + /// XMPP client connection and state /// /// This implements the `futures` crate's [`Stream`](#impl-Stream) to receive @@ -37,6 +40,7 @@ pub struct Client { stream: StanzaStream, bound_jid: Option, features: Option, + iq_response_tracker: iq::IqResponseTracker, } impl Client { @@ -59,6 +63,9 @@ impl Client { /// method can be called with [`StanzaStage::Acked`], but that stage will /// only ever be reached if the server supports XEP-0198 and it has been /// negotiated successfully (this may change in the future). + /// + /// For sending Iq request stanzas, it is recommended to use + /// [`send_iq`][`Self::send_iq`], which allows awaiting the response. pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result { stanza.ensure_id(); let mut token = self.stream.send(Box::new(stanza)).await; @@ -75,6 +82,33 @@ impl Client { } } + /// Send an IQ request and return a token to retrieve the response. + /// + /// This coroutine method will complete once the Iq has been sent to the + /// server. The returned `IqResponseToken` can be used to await the + /// response. See also the documentation of [`IqResponseToken`] for more + /// information on the behaviour of these tokens. + /// + /// **Important**: Even though IQ responses are delivered through the + /// returned token (and never through the `Stream`), the + /// [`Stream`][`futures::Stream`] + /// implementation of the [`Client`] **must be polled** to make progress + /// on the stream and to process incoming stanzas and thus to deliver them + /// to the returned token. + /// + /// **Note**: If an IQ response arrives after the `token` has been + /// dropped (e.g. due to a timeout), it will be delivered through the + /// `Stream` like any other stanza. + pub async fn send_iq(&mut self, to: Option, req: IqRequest) -> IqResponseToken { + let (iq, mut token) = self.iq_response_tracker.allocate_iq_handle( + // from is always None for a client + None, to, req, + ); + let stanza_token = self.stream.send(Box::new(iq.into())).await; + token.set_stanza_token(stanza_token); + token + } + /// Get the stream features (``) of the underlying /// stream. /// @@ -153,6 +187,7 @@ impl Client { stream: StanzaStream::new_c2s(connector, jid.into(), password.into(), timeouts, 16), bound_jid: None, features: None, + iq_response_tracker: iq::IqResponseTracker::new(), } } } diff --git a/tokio-xmpp/src/client/stream.rs b/tokio-xmpp/src/client/stream.rs index c585276b4394793aca253abfd61e9584ffe19205..84858686f74d135fbd1849db9ba6efd168e5ede9 100644 --- a/tokio-xmpp/src/client/stream.rs +++ b/tokio-xmpp/src/client/stream.rs @@ -4,13 +4,14 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +use core::ops::ControlFlow; use core::{pin::Pin, task::Context}; use futures::{ready, task::Poll, Stream}; use crate::{ client::Client, stanzastream::{Event as StanzaStreamEvent, StreamEvent}, - Event, + Event, Stanza, }; /// Incoming XMPP events @@ -34,7 +35,13 @@ impl Stream for Client { loop { return Poll::Ready(match ready!(Pin::new(&mut self.stream).poll_next(cx)) { None => None, - Some(StanzaStreamEvent::Stanza(st)) => Some(Event::Stanza(st)), + Some(StanzaStreamEvent::Stanza(st)) => match st { + Stanza::Iq(iq) => match self.iq_response_tracker.handle_iq(iq) { + ControlFlow::Break(()) => continue, + ControlFlow::Continue(iq) => Some(Event::Stanza(Stanza::Iq(iq))), + }, + other => Some(Event::Stanza(other)), + }, Some(StanzaStreamEvent::Stream(StreamEvent::Reset { bound_jid, features, diff --git a/tokio-xmpp/src/event.rs b/tokio-xmpp/src/event.rs index e789c26e18b93a3d76999981777b474e9b9f99eb..c7815a3ca507d04d383bf1f6123ef98b622d0466 100644 --- a/tokio-xmpp/src/event.rs +++ b/tokio-xmpp/src/event.rs @@ -16,7 +16,7 @@ use xso::{AsXml, FromXml}; use crate::xmlstream::XmppStreamElement; use crate::Error; -fn make_id() -> String { +pub(crate) fn make_id() -> String { let id: u64 = thread_rng().gen(); format!("{}", id) } diff --git a/tokio-xmpp/src/lib.rs b/tokio-xmpp/src/lib.rs index b957f2607636d22302309f0774095e8af15b67ed..a08d11453beb603bbe6927abb00052799765b0b5 100644 --- a/tokio-xmpp/src/lib.rs +++ b/tokio-xmpp/src/lib.rs @@ -64,7 +64,7 @@ pub mod xmlstream; #[doc(inline)] /// Generic tokio_xmpp Error pub use crate::error::Error; -pub use client::Client; +pub use client::{Client, IqFailure, IqRequest, IqResponse, IqResponseToken}; #[cfg(feature = "insecure-tcp")] pub use component::Component; pub use event::{Event, Stanza}; diff --git a/tokio-xmpp/src/stanzastream/queue.rs b/tokio-xmpp/src/stanzastream/queue.rs index ffb0900154a063be26839ef8fc2a93379ed59b3a..04a8e3afdf67a80c6091cd5520c8512a9fe295a9 100644 --- a/tokio-xmpp/src/stanzastream/queue.rs +++ b/tokio-xmpp/src/stanzastream/queue.rs @@ -184,6 +184,10 @@ impl StanzaToken { .ok() } + pub(crate) fn into_stream(self) -> tokio_stream::wrappers::WatchStream { + tokio_stream::wrappers::WatchStream::new(self.inner) + } + /// Read the current transmission state. pub fn state(&self) -> StanzaState { self.inner.borrow().clone()