From ab10e30ac0cfa01fc0e3092c5dcb861aff44b797 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Sch=C3=A4fer?= Date: Sat, 10 Aug 2024 15:05:42 +0200 Subject: [PATCH] Port crates to use new XSO-based xmlstream --- tokio-xmpp/Cargo.toml | 1 - tokio-xmpp/examples/contact_addr.rs | 31 +- tokio-xmpp/examples/download_avatars.rs | 172 +++++----- tokio-xmpp/examples/echo_bot.rs | 13 +- tokio-xmpp/examples/echo_component.rs | 13 +- tokio-xmpp/examples/echo_server.rs | 20 +- tokio-xmpp/src/client/bind.rs | 55 ++-- tokio-xmpp/src/client/login.rs | 106 +++--- tokio-xmpp/src/client/mod.rs | 24 +- tokio-xmpp/src/client/stream.rs | 122 ++++--- tokio-xmpp/src/component/login.rs | 59 ++-- tokio-xmpp/src/component/mod.rs | 14 +- tokio-xmpp/src/component/stream.rs | 24 +- tokio-xmpp/src/connect/mod.rs | 12 +- tokio-xmpp/src/connect/starttls.rs | 105 ++++-- tokio-xmpp/src/connect/tcp.rs | 25 +- tokio-xmpp/src/event.rs | 117 ++++++- tokio-xmpp/src/lib.rs | 3 +- tokio-xmpp/src/proto/mod.rs | 8 - tokio-xmpp/src/proto/xmpp_codec.rs | 410 ------------------------ tokio-xmpp/src/proto/xmpp_stream.rs | 193 ----------- tokio-xmpp/src/xmlstream/common.rs | 4 + tokio-xmpp/src/xmlstream/initiator.rs | 15 + tokio-xmpp/src/xmlstream/mod.rs | 9 +- tokio-xmpp/src/xmlstream/xmpp.rs | 6 +- xmpp/src/event_loop.rs | 35 +- 26 files changed, 623 insertions(+), 973 deletions(-) delete mode 100644 tokio-xmpp/src/proto/mod.rs delete mode 100644 tokio-xmpp/src/proto/xmpp_codec.rs delete mode 100644 tokio-xmpp/src/proto/xmpp_stream.rs diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 709291e464e0d389c13ab78ce35f733ba79511c3..49e9b9f58ff76cf9eebd500abc4f76459c9aa2f9 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -17,7 +17,6 @@ futures = "0.3" log = "0.4" tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] } tokio-stream = { version = "0.1", features = [] } -tokio-util = { version = "0.7", features = ["codec"] } webpki-roots = { version = "0.26", optional = true } rustls-native-certs = { version = "0.7", optional = true } rxml = { version = "0.12.0", features = ["compact_str"] } diff --git a/tokio-xmpp/examples/contact_addr.rs b/tokio-xmpp/examples/contact_addr.rs index b473b11fa0f79345070bfa802283b18e32b2014a..d9a14484cec58323640dc82b89426a4a5c3fd9b6 100644 --- a/tokio-xmpp/examples/contact_addr.rs +++ b/tokio-xmpp/examples/contact_addr.rs @@ -1,9 +1,8 @@ use futures::stream::StreamExt; -use minidom::Element; use std::env::args; use std::process::exit; use std::str::FromStr; -use tokio_xmpp::Client; +use tokio_xmpp::{Client, Stanza}; use xmpp_parsers::{ disco::{DiscoInfoQuery, DiscoInfoResult}, iq::{Iq, IqType}, @@ -41,24 +40,21 @@ async fn main() { let target_jid: Jid = target.clone().parse().unwrap(); let iq = make_disco_iq(target_jid); println!("Sending disco#info request to {}", target.clone()); - println!(">> {}", String::from(&iq)); - client.send_stanza(iq).await.unwrap(); - } else if let Some(stanza) = event.into_stanza() { - if stanza.is("iq", "jabber:client") { - let iq = Iq::try_from(stanza).unwrap(); - if let IqType::Result(Some(payload)) = iq.payload { - if payload.is("query", ns::DISCO_INFO) { - if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { - for ext in disco_info.extensions { - if let Ok(server_info) = ServerInfo::try_from(ext) { - print_server_info(server_info); - } + println!(">> {:?}", iq); + client.send_stanza(iq.into()).await.unwrap(); + } else if let Some(Stanza::Iq(iq)) = event.into_stanza() { + if let IqType::Result(Some(payload)) = iq.payload { + if payload.is("query", ns::DISCO_INFO) { + if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { + for ext in disco_info.extensions { + if let Ok(server_info) = ServerInfo::try_from(ext) { + print_server_info(server_info); } } } - wait_for_stream_end = true; - client.send_end().await.unwrap(); } + wait_for_stream_end = true; + client.send_end().await.unwrap(); } } } else { @@ -67,11 +63,10 @@ async fn main() { } } -fn make_disco_iq(target: Jid) -> Element { +fn make_disco_iq(target: Jid) -> Iq { Iq::from_get("disco", DiscoInfoQuery { node: None }) .with_id(String::from("contact")) .with_to(target) - .into() } fn convert_field(field: Vec) -> String { diff --git a/tokio-xmpp/examples/download_avatars.rs b/tokio-xmpp/examples/download_avatars.rs index da818dce7789b23a5ecf6c2b5563acdc3697041f..db8e64384eb017bcda042ec6aac0a3fa792e54e0 100644 --- a/tokio-xmpp/examples/download_avatars.rs +++ b/tokio-xmpp/examples/download_avatars.rs @@ -1,11 +1,10 @@ use futures::stream::StreamExt; -use minidom::Element; use std::env::args; use std::fs::{create_dir_all, File}; use std::io::{self, Write}; use std::process::exit; use std::str::FromStr; -use tokio_xmpp::Client; +use tokio_xmpp::{Client, Stanza}; use xmpp_parsers::{ avatar::{Data as AvatarData, Metadata as AvatarMetadata}, caps::{compute_disco, hash_caps, Caps}, @@ -13,7 +12,6 @@ use xmpp_parsers::{ hashes::Algo, iq::{Iq, IqType}, jid::{BareJid, Jid}, - message::Message, ns, presence::{Presence, Type as PresenceType}, pubsub::{ @@ -55,100 +53,107 @@ async fn main() { let presence = make_presence(caps); client.send_stanza(presence.into()).await.unwrap(); } else if let Some(stanza) = event.into_stanza() { - if stanza.is("iq", "jabber:client") { - let iq = Iq::try_from(stanza).unwrap(); - if let IqType::Get(payload) = iq.payload { - if payload.is("query", ns::DISCO_INFO) { - let query = DiscoInfoQuery::try_from(payload); - match query { - Ok(query) => { - let mut disco = disco_info.clone(); - disco.node = query.node; - let iq = Iq::from_result(iq.id, Some(disco)) - .with_to(iq.from.unwrap()); - client.send_stanza(iq.into()).await.unwrap(); + match stanza { + Stanza::Iq(iq) => { + if let IqType::Get(payload) = iq.payload { + if payload.is("query", ns::DISCO_INFO) { + let query = DiscoInfoQuery::try_from(payload); + match query { + Ok(query) => { + let mut disco = disco_info.clone(); + disco.node = query.node; + let iq = Iq::from_result(iq.id, Some(disco)) + .with_to(iq.from.unwrap()); + client.send_stanza(iq.into()).await.unwrap(); + } + Err(err) => client + .send_stanza( + make_error( + iq.from.unwrap(), + iq.id, + ErrorType::Modify, + DefinedCondition::BadRequest, + &format!("{}", err), + ) + .into(), + ) + .await + .unwrap(), } - Err(err) => client - .send_stanza(make_error( - iq.from.unwrap(), - iq.id, - ErrorType::Modify, - DefinedCondition::BadRequest, - &format!("{}", err), - )) + } else { + // We MUST answer unhandled get iqs with a service-unavailable error. + client + .send_stanza( + make_error( + iq.from.unwrap(), + iq.id, + ErrorType::Cancel, + DefinedCondition::ServiceUnavailable, + "No handler defined for this kind of iq.", + ) + .into(), + ) .await - .unwrap(), + .unwrap(); } - } else { - // We MUST answer unhandled get iqs with a service-unavailable error. + } else if let IqType::Result(Some(payload)) = iq.payload { + if payload.is("pubsub", ns::PUBSUB) { + let pubsub = PubSub::try_from(payload).unwrap(); + let from = iq.from.clone().unwrap_or(jid.clone().into()); + handle_iq_result(pubsub, &from); + } + } else if let IqType::Set(_) = iq.payload { + // We MUST answer unhandled set iqs with a service-unavailable error. client - .send_stanza(make_error( - iq.from.unwrap(), - iq.id, - ErrorType::Cancel, - DefinedCondition::ServiceUnavailable, - "No handler defined for this kind of iq.", - )) + .send_stanza( + make_error( + iq.from.unwrap(), + iq.id, + ErrorType::Cancel, + DefinedCondition::ServiceUnavailable, + "No handler defined for this kind of iq.", + ) + .into(), + ) .await .unwrap(); } - } else if let IqType::Result(Some(payload)) = iq.payload { - if payload.is("pubsub", ns::PUBSUB) { - let pubsub = PubSub::try_from(payload).unwrap(); - let from = iq.from.clone().unwrap_or(jid.clone().into()); - handle_iq_result(pubsub, &from); - } - } else if let IqType::Set(_) = iq.payload { - // We MUST answer unhandled set iqs with a service-unavailable error. - client - .send_stanza(make_error( - iq.from.unwrap(), - iq.id, - ErrorType::Cancel, - DefinedCondition::ServiceUnavailable, - "No handler defined for this kind of iq.", - )) - .await - .unwrap(); } - } else if stanza.is("message", "jabber:client") { - let message = Message::try_from(stanza).unwrap(); - let from = message.from.clone().unwrap(); - if let Some(body) = message.get_best_body(vec!["en"]) { - if body.0 == "die" { - println!("Secret die command triggered by {}", from); - wait_for_stream_end = true; - client.send_end().await.unwrap(); + Stanza::Message(message) => { + let from = message.from.clone().unwrap(); + if let Some(body) = message.get_best_body(vec!["en"]) { + if body.0 == "die" { + println!("Secret die command triggered by {}", from); + wait_for_stream_end = true; + client.send_end().await.unwrap(); + } } - } - for child in message.payloads { - if child.is("event", ns::PUBSUB_EVENT) { - let event = PubSubEvent::try_from(child).unwrap(); - if let PubSubEvent::PublishedItems { node, items } = event { - if node.0 == ns::AVATAR_METADATA { - for item in items.into_iter() { - let payload = item.payload.clone().unwrap(); - if payload.is("metadata", ns::AVATAR_METADATA) { - // TODO: do something with these metadata. - let _metadata = - AvatarMetadata::try_from(payload).unwrap(); - println!( - "{} has published an avatar, downloading...", - from.clone() - ); - let iq = download_avatar(from.clone()); - client.send_stanza(iq.into()).await.unwrap(); + for child in message.payloads { + if child.is("event", ns::PUBSUB_EVENT) { + let event = PubSubEvent::try_from(child).unwrap(); + if let PubSubEvent::PublishedItems { node, items } = event { + if node.0 == ns::AVATAR_METADATA { + for item in items.into_iter() { + let payload = item.payload.clone().unwrap(); + if payload.is("metadata", ns::AVATAR_METADATA) { + // TODO: do something with these metadata. + let _metadata = + AvatarMetadata::try_from(payload).unwrap(); + println!( + "{} has published an avatar, downloading...", + from.clone() + ); + let iq = download_avatar(from.clone()); + client.send_stanza(iq.into()).await.unwrap(); + } } } } } } } - } else if stanza.is("presence", "jabber:client") { // Nothing to do here. - () - } else { - panic!("Unknown stanza: {}", String::from(&stanza)); + Stanza::Presence(_) => (), } } } else { @@ -164,10 +169,9 @@ fn make_error( type_: ErrorType, condition: DefinedCondition, text: &str, -) -> Element { +) -> Iq { let error = StanzaError::new(type_, condition, "en", text); - let iq = Iq::from_error(id, error).with_to(to); - iq.into() + Iq::from_error(id, error).with_to(to) } fn make_disco() -> DiscoInfoResult { diff --git a/tokio-xmpp/examples/echo_bot.rs b/tokio-xmpp/examples/echo_bot.rs index 9810c598aac57147e22f5350e38162695af41305..055ff8f770153b72f37137203b9afe724c443000 100644 --- a/tokio-xmpp/examples/echo_bot.rs +++ b/tokio-xmpp/examples/echo_bot.rs @@ -1,5 +1,4 @@ use futures::stream::StreamExt; -use minidom::Element; use std::env::args; use std::process::exit; use std::str::FromStr; @@ -40,7 +39,7 @@ async fn main() { println!("Online at {}", jid); let presence = make_presence(); - client.send_stanza(presence).await.unwrap(); + client.send_stanza(presence.into()).await.unwrap(); } else if let Some(message) = event .into_stanza() .and_then(|stanza| Message::try_from(stanza).ok()) @@ -55,7 +54,7 @@ async fn main() { if message.type_ != MessageType::Error { // This is a message we'll echo let reply = make_reply(from.clone(), &body.0); - client.send_stanza(reply).await.unwrap(); + client.send_stanza(reply.into()).await.unwrap(); } } _ => {} @@ -69,18 +68,18 @@ async fn main() { } // Construct a -fn make_presence() -> Element { +fn make_presence() -> Presence { let mut presence = Presence::new(PresenceType::None); presence.show = Some(PresenceShow::Chat); presence .statuses .insert(String::from("en"), String::from("Echoing messages.")); - presence.into() + presence } // Construct a chat -fn make_reply(to: Jid, body: &str) -> Element { +fn make_reply(to: Jid, body: &str) -> Message { let mut message = Message::new(Some(to)); message.bodies.insert(String::new(), Body(body.to_owned())); - message.into() + message } diff --git a/tokio-xmpp/examples/echo_component.rs b/tokio-xmpp/examples/echo_component.rs index 87481140705365587570e9e279f03eb0fcb4eccf..70ffa538fe62dea3162fa55bbd78d0efa301770a 100644 --- a/tokio-xmpp/examples/echo_component.rs +++ b/tokio-xmpp/examples/echo_component.rs @@ -1,5 +1,4 @@ use futures::stream::StreamExt; -use minidom::Element; use std::env::args; use std::process::exit; use std::str::FromStr; @@ -45,7 +44,7 @@ async fn main() { Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(), Jid::from_str("linkmauve@linkmauve.fr").unwrap(), ); - component.send_stanza(presence).await.unwrap(); + component.send_stanza(presence.into()).await.unwrap(); // Main loop, processes events loop { @@ -56,7 +55,7 @@ async fn main() { (Some(from), Some(body)) => { if message.type_ != MessageType::Error { let reply = make_reply(from, &body.0); - component.send_stanza(reply).await.unwrap(); + component.send_stanza(reply.into()).await.unwrap(); } } _ => (), @@ -69,7 +68,7 @@ async fn main() { } // Construct a -fn make_presence(from: Jid, to: Jid) -> Element { +fn make_presence(from: Jid, to: Jid) -> Presence { let mut presence = Presence::new(PresenceType::None); presence.from = Some(from); presence.to = Some(to); @@ -77,12 +76,12 @@ fn make_presence(from: Jid, to: Jid) -> Element { presence .statuses .insert(String::from("en"), String::from("Echoing messages.")); - presence.into() + presence } // Construct a chat -fn make_reply(to: Jid, body: &str) -> Element { +fn make_reply(to: Jid, body: &str) -> Message { let mut message = Message::new(Some(to)); message.bodies.insert(String::new(), Body(body.to_owned())); - message.into() + message } diff --git a/tokio-xmpp/examples/echo_server.rs b/tokio-xmpp/examples/echo_server.rs index 67664715888ab8b269bf708f616bf40c7f73d348..dbf2eb0570f42a7644eb5a1b5597dcc68332c8e6 100644 --- a/tokio-xmpp/examples/echo_server.rs +++ b/tokio-xmpp/examples/echo_server.rs @@ -1,8 +1,8 @@ use futures::{SinkExt, StreamExt}; use tokio::{self, io, net::TcpSocket}; -use tokio_util::codec::Framed; -use tokio_xmpp::proto::XmppCodec; +use tokio_xmpp::parsers::stream_features::StreamFeatures; +use tokio_xmpp::xmlstream::{accept_stream, StreamHeader}; #[tokio::main] async fn main() -> Result<(), io::Error> { @@ -16,16 +16,22 @@ async fn main() -> Result<(), io::Error> { // Main loop, accepts incoming connections loop { let (stream, _addr) = listener.accept().await?; - - // Use the `XMPPCodec` to encode and decode frames - let mut framed = Framed::new(stream, XmppCodec::new()); + let stream = accept_stream( + tokio::io::BufStream::new(stream), + tokio_xmpp::parsers::ns::DEFAULT_NS, + ) + .await?; + let stream = stream.send_header(StreamHeader::default()).await?; + let mut stream = stream + .send_features::(&StreamFeatures::default()) + .await?; tokio::spawn(async move { - while let Some(packet) = framed.next().await { + while let Some(packet) = stream.next().await { match packet { Ok(packet) => { println!("Received packet: {:?}", packet); - framed.send(packet).await.unwrap(); + stream.send(&packet).await.unwrap(); } Err(e) => { eprintln!("Error: {:?}", e); diff --git a/tokio-xmpp/src/client/bind.rs b/tokio-xmpp/src/client/bind.rs index 2f282ad43b821c4f2747a5edc6e8bc0da5ad72bc..dfc94edfb8f48cb05e42acd367280169a12e54aa 100644 --- a/tokio-xmpp/src/client/bind.rs +++ b/tokio-xmpp/src/client/bind.rs @@ -1,46 +1,53 @@ -use futures::stream::StreamExt; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::io; + +use futures::{SinkExt, StreamExt}; +use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::bind::{BindQuery, BindResponse}; use xmpp_parsers::iq::{Iq, IqType}; +use xmpp_parsers::stream_features::StreamFeatures; use crate::error::{Error, ProtocolError}; -use crate::proto::{Packet, XmppStream}; +use crate::jid::{FullJid, Jid}; +use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement}; const BIND_REQ_ID: &str = "resource-bind"; -pub async fn bind( - mut stream: XmppStream, -) -> Result, Error> { - if stream.stream_features.can_bind() { - let resource = stream - .jid +pub async fn bind( + stream: &mut XmppStream, + features: &StreamFeatures, + jid: &Jid, +) -> Result, Error> { + if features.can_bind() { + let resource = jid .resource() .and_then(|resource| Some(resource.to_string())); let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); - stream.send_stanza(iq).await?; + stream.send(&XmppStreamElement::Iq(iq)).await?; loop { match stream.next().await { - Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) { - Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload { - IqType::Result(payload) => { - payload - .and_then(|payload| BindResponse::try_from(payload).ok()) - .map(|bind| stream.jid = bind.into()); - return Ok(stream); + Some(Ok(XmppStreamElement::Iq(iq))) if iq.id == BIND_REQ_ID => match iq.payload { + IqType::Result(Some(payload)) => match BindResponse::try_from(payload) { + Ok(v) => { + return Ok(Some(v.into())); } - _ => return Err(ProtocolError::InvalidBindResponse.into()), + Err(_) => return Err(ProtocolError::InvalidBindResponse.into()), }, - _ => {} + _ => return Err(ProtocolError::InvalidBindResponse.into()), }, Some(Ok(_)) => {} - Some(Err(e)) => return Err(e), - None => return Err(Error::Disconnected), + Some(Err(ReadError::SoftTimeout)) => {} + Some(Err(ReadError::HardError(e))) => return Err(e.into()), + Some(Err(ReadError::ParseError(e))) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e).into()) + } + Some(Err(ReadError::StreamFooterReceived)) | None => { + return Err(Error::Disconnected) + } } } } else { - // No resource binding available, - // return the (probably // usable) stream immediately - return Ok(stream); + // No resource binding available, do nothing. + return Ok(None); } } diff --git a/tokio-xmpp/src/client/login.rs b/tokio-xmpp/src/client/login.rs index 5e4d12191ddfe65c3a8053a19c88ff12ddab4bf2..14cb029318d94afdea60dd5d94412d1fd5244bd7 100644 --- a/tokio-xmpp/src/client/login.rs +++ b/tokio-xmpp/src/client/login.rs @@ -1,25 +1,32 @@ -use futures::stream::StreamExt; +use futures::{SinkExt, StreamExt}; use sasl::client::mechanisms::{Anonymous, Plain, Scram}; use sasl::client::Mechanism; use sasl::common::scram::{Sha1, Sha256}; use sasl::common::Credentials; +use std::borrow::Cow; use std::collections::HashSet; +use std::io; use std::str::FromStr; -use tokio::io::{AsyncRead, AsyncWrite}; -use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success}; -use xmpp_parsers::{jid::Jid, ns}; +use tokio::io::{AsyncBufRead, AsyncWrite}; +use xmpp_parsers::{ + jid::{FullJid, Jid}, + ns, + sasl::{Auth, Mechanism as XMPPMechanism, Nonza, Response}, + stream_features::{SaslMechanisms, StreamFeatures}, +}; use crate::{ client::bind::bind, connect::ServerConnector, error::{AuthError, Error, ProtocolError}, - proto::{Packet, XmppStream}, + xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream}, }; -pub async fn auth( +pub async fn auth( mut stream: XmppStream, + sasl_mechanisms: &SaslMechanisms, creds: Credentials, -) -> Result { +) -> Result, Error> { let local_mechs: Vec Box + Send>> = vec![ Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), @@ -27,13 +34,7 @@ pub async fn auth( Box::new(|| Box::new(Anonymous::new())), ]; - let remote_mechs: HashSet = stream - .stream_features - .sasl_mechanisms - .mechanisms - .iter() - .cloned() - .collect(); + let remote_mechs: HashSet = sasl_mechanisms.mechanisms.iter().cloned().collect(); for local_mech in local_mechs { let mut mechanism = local_mech(); @@ -43,43 +44,55 @@ pub async fn auth( XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?; stream - .send_stanza(Auth { + .send(&XmppStreamElement::Sasl(Nonza::Auth(Auth { mechanism: mechanism_name, data: initial, - }) + }))) .await?; loop { match stream.next().await { - Some(Ok(Packet::Stanza(stanza))) => { - if let Ok(challenge) = Challenge::try_from(stanza.clone()) { + Some(Ok(XmppStreamElement::Sasl(sasl))) => match sasl { + Nonza::Challenge(challenge) => { let response = mechanism .response(&challenge.data) .map_err(|e| AuthError::Sasl(e))?; // Send response and loop - stream.send_stanza(Response { data: response }).await?; - } else if let Ok(_) = Success::try_from(stanza.clone()) { - return Ok(stream.into_inner()); - } else if let Ok(failure) = Failure::try_from(stanza.clone()) { + stream + .send(&XmppStreamElement::Sasl(Nonza::Response(Response { + data: response, + }))) + .await?; + } + Nonza::Success(_) => return Ok(stream.initiate_reset()), + Nonza::Failure(failure) => { return Err(Error::Auth(AuthError::Fail(failure.defined_condition))); - // TODO: This code was needed for compatibility with some broken server, - // but it’s been forgotten which. It is currently commented out so that we - // can find it and fix the server software instead. - /* - } else if stanza.name() == "failure" { - // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1 - return Err(Error::Auth(AuthError::Sasl("failure".to_string()))); - */ - } else { - // ignore and loop } + _ => { + // Ignore?! + } + }, + Some(Ok(el)) => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "unexpected stream element during SASL negotiation: {:?}", + el + ), + ) + .into()) + } + Some(Err(ReadError::HardError(e))) => return Err(e.into()), + Some(Err(ReadError::ParseError(e))) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e).into()) + } + Some(Err(ReadError::SoftTimeout)) => { + // We cannot do anything about soft timeouts here... } - Some(Ok(_)) => { - // ignore and loop + Some(Err(ReadError::StreamFooterReceived)) | None => { + return Err(Error::Disconnected) } - Some(Err(e)) => return Err(e), - None => return Err(Error::Disconnected), } } } @@ -94,24 +107,31 @@ pub async fn client_login( server: C, jid: Jid, password: String, -) -> Result, Error> { +) -> Result<(Option, StreamFeatures, XmppStream), Error> { let username = jid.node().unwrap().as_str(); let password = password; let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?; + let (features, xmpp_stream) = xmpp_stream.recv_features().await?; - let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?; + let channel_binding = C::channel_binding(xmpp_stream.get_stream())?; let creds = Credentials::default() .with_username(username) .with_password(password) .with_channel_binding(channel_binding); // Authenticated (unspecified) stream - let stream = auth(xmpp_stream, creds).await?; - // Authenticated XmppStream - let xmpp_stream = XmppStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?; + let stream = auth(xmpp_stream, &features.sasl_mechanisms, creds).await?; + let stream = stream + .send_header(StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }) + .await?; + let (features, mut stream) = stream.recv_features().await?; // XmppStream bound to user session - let xmpp_stream = bind(xmpp_stream).await?; - Ok(xmpp_stream) + let full_jid = bind(&mut stream, &features, &jid).await?; + Ok((full_jid, features, stream)) } diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index 3633aa1010e1872f6fe178d5de9da5bf3f57ffa6..13fcf6f289f9b43076b08462fa0193155cee47bc 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -1,12 +1,11 @@ use futures::sink::SinkExt; -use minidom::Element; -use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures}; +use xmpp_parsers::{jid::Jid, stream_features::StreamFeatures}; use crate::{ client::{login::client_login, stream::ClientState}, connect::ServerConnector, error::Error, - proto::{add_stanza_id, Packet}, + Stanza, }; #[cfg(any(feature = "starttls", feature = "insecure-tcp"))] @@ -47,21 +46,21 @@ impl Client { /// server). pub fn bound_jid(&self) -> Option<&Jid> { match self.state { - ClientState::Connected(ref stream) => Some(&stream.jid), + ClientState::Connected { ref bound_jid, .. } => Some(bound_jid), _ => None, } } /// Send stanza - pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> { - self.send(Packet::Stanza(add_stanza_id(stanza, ns::JABBER_CLIENT))) - .await + pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<(), Error> { + stanza.ensure_id(); + self.send(stanza).await } /// Get the stream features (``) of the underlying stream pub fn get_stream_features(&self) -> Option<&StreamFeatures> { match self.state { - ClientState::Connected(ref stream) => Some(&stream.stream_features), + ClientState::Connected { ref features, .. } => Some(features), _ => None, } } @@ -73,7 +72,14 @@ impl Client { /// /// Make sure to disable reconnect. pub async fn send_end(&mut self) -> Result<(), Error> { - self.send(Packet::StreamEnd).await + match self.state { + ClientState::Connected { ref mut stream, .. } => Ok(stream.close().await?), + ClientState::Connecting { .. } => { + self.state = ClientState::Disconnected; + Ok(()) + } + _ => Ok(()), + } } } diff --git a/tokio-xmpp/src/client/stream.rs b/tokio-xmpp/src/client/stream.rs index b185bd136e8c4033c48689ab6f16589b8c2d2f42..7b19ecb8de5a6d22bdea28f6d3f0849a2a898a6e 100644 --- a/tokio-xmpp/src/client/stream.rs +++ b/tokio-xmpp/src/client/stream.rs @@ -1,22 +1,31 @@ use futures::{task::Poll, Future, Sink, Stream}; +use std::io; use std::mem::replace; use std::pin::Pin; use std::task::Context; use tokio::task::JoinHandle; +use xmpp_parsers::{ + jid::{FullJid, Jid}, + stream_features::StreamFeatures, +}; use crate::{ - client::login::client_login, + client::{login::client_login, Client}, connect::{AsyncReadAndWrite, ServerConnector}, - error::{Error, ProtocolError}, - proto::{Packet, XmppStream}, - Client, Event, + error::Error, + xmlstream::{xmpp::XmppStreamElement, ReadError, XmppStream}, + Event, Stanza, }; pub(crate) enum ClientState { Invalid, Disconnected, - Connecting(JoinHandle, Error>>), - Connected(XmppStream), + Connecting(JoinHandle, StreamFeatures, XmppStream), Error>>), + Connected { + stream: XmppStream, + features: StreamFeatures, + bound_jid: Jid, + }, } /// Incoming XMPP events @@ -56,9 +65,13 @@ impl Stream for Client { Poll::Ready(None) } ClientState::Connecting(mut connect) => match Pin::new(&mut connect).poll(cx) { - Poll::Ready(Ok(Ok(stream))) => { - let bound_jid = stream.jid.clone(); - self.state = ClientState::Connected(stream); + Poll::Ready(Ok(Ok((bound_jid, features, stream)))) => { + let bound_jid = bound_jid.map(Jid::from).unwrap_or_else(|| self.jid.clone()); + self.state = ClientState::Connected { + stream, + bound_jid: bound_jid.clone(), + features, + }; Poll::Ready(Some(Event::Online { bound_jid, resumed: false, @@ -77,7 +90,11 @@ impl Stream for Client { Poll::Pending } }, - ClientState::Connected(mut stream) => { + ClientState::Connected { + mut stream, + features, + bound_jid, + } => { // Poll sink match Pin::new(&mut stream).poll_ready(cx) { Poll::Pending => (), @@ -99,40 +116,69 @@ impl Stream for Client { // return. loop { match Pin::new(&mut stream).poll_next(cx) { - Poll::Ready(None) => { + Poll::Ready(None) + | Poll::Ready(Some(Err(ReadError::StreamFooterReceived))) => { // EOF self.state = ClientState::Disconnected; return Poll::Ready(Some(Event::Disconnected(Error::Disconnected))); } - Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => { - // Receive stanza - self.state = ClientState::Connected(stream); - return Poll::Ready(Some(Event::Stanza(stanza))); - } - Poll::Ready(Some(Ok(Packet::Text(_)))) => { - // Ignore text between stanzas + Poll::Ready(Some(Err(ReadError::HardError(e)))) => { + // Treat stream as dead on I/O errors + self.state = ClientState::Disconnected; + return Poll::Ready(Some(Event::Disconnected(e.into()))); } - Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => { - // + Poll::Ready(Some(Err(ReadError::ParseError(e)))) => { + // Treat stream as dead on parse errors, too (for now...) self.state = ClientState::Disconnected; return Poll::Ready(Some(Event::Disconnected( - ProtocolError::InvalidStreamStart.into(), + io::Error::new(io::ErrorKind::InvalidData, e).into(), ))); } - Poll::Ready(Some(Ok(Packet::StreamEnd))) => { - // End of stream: - self.state = ClientState::Disconnected; - return Poll::Ready(Some(Event::Disconnected(Error::Disconnected))); + Poll::Ready(Some(Err(ReadError::SoftTimeout))) => { + // TODO: do something smart about this. + } + Poll::Ready(Some(Ok(XmppStreamElement::Iq(stanza)))) => { + // Receive stanza + self.state = ClientState::Connected { + stream, + features, + bound_jid, + }; + // TODO: use specific stanza types instead of going back to elements... + return Poll::Ready(Some(Event::Stanza(stanza.into()))); + } + Poll::Ready(Some(Ok(XmppStreamElement::Message(stanza)))) => { + // Receive stanza + self.state = ClientState::Connected { + stream, + features, + bound_jid, + }; + // TODO: use specific stanza types instead of going back to elements... + return Poll::Ready(Some(Event::Stanza(stanza.into()))); + } + Poll::Ready(Some(Ok(XmppStreamElement::Presence(stanza)))) => { + // Receive stanza + self.state = ClientState::Connected { + stream, + features, + bound_jid, + }; + // TODO: use specific stanza types instead of going back to elements... + return Poll::Ready(Some(Event::Stanza(stanza.into()))); + } + Poll::Ready(Some(Ok(_))) => { + // We ignore these for now. } Poll::Pending => { // Try again later - self.state = ClientState::Connected(stream); + self.state = ClientState::Connected { + stream, + features, + bound_jid, + }; return Poll::Pending; } - Poll::Ready(Some(Err(e))) => { - self.state = ClientState::Disconnected; - return Poll::Ready(Some(Event::Disconnected(e.into()))); - } } } } @@ -143,21 +189,21 @@ impl Stream for Client { /// Outgoing XMPP packets /// /// See `send_stanza()` for an `async fn` -impl Sink for Client { +impl Sink for Client { type Error = Error; - fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { match self.state { - ClientState::Connected(ref mut stream) => { - Pin::new(stream).start_send(item).map_err(|e| e.into()) - } + ClientState::Connected { ref mut stream, .. } => Pin::new(stream) + .start_send(&item.into()) + .map_err(|e| e.into()), _ => Err(Error::InvalidState), } } fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.state { - ClientState::Connected(ref mut stream) => { + ClientState::Connected { ref mut stream, .. } => { Pin::new(stream).poll_ready(cx).map_err(|e| e.into()) } _ => Poll::Pending, @@ -166,7 +212,7 @@ impl Sink for Client { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.state { - ClientState::Connected(ref mut stream) => { + ClientState::Connected { ref mut stream, .. } => { Pin::new(stream).poll_flush(cx).map_err(|e| e.into()) } _ => Poll::Pending, @@ -175,7 +221,7 @@ impl Sink for Client { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.state { - ClientState::Connected(ref mut stream) => { + ClientState::Connected { ref mut stream, .. } => { Pin::new(stream).poll_close(cx).map_err(|e| e.into()) } _ => Poll::Pending, diff --git a/tokio-xmpp/src/component/login.rs b/tokio-xmpp/src/component/login.rs index d72781dbe91426663db06ab07bd487d4ba402ea2..427ef89ad8e1ccadace5884cf64be807911d6d2b 100644 --- a/tokio-xmpp/src/component/login.rs +++ b/tokio-xmpp/src/component/login.rs @@ -1,12 +1,12 @@ -use futures::stream::StreamExt; -use tokio::io::{AsyncRead, AsyncWrite}; +use std::io; + +use futures::{SinkExt, StreamExt}; +use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::{component::Handshake, jid::Jid, ns}; -use crate::{ - connect::ServerConnector, - error::{AuthError, Error}, - proto::{Packet, XmppStream}, -}; +use crate::component::ServerConnector; +use crate::error::{AuthError, Error}; +use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement}; /// Log into an XMPP server as a client with a jid+pass pub async fn component_login( @@ -15,32 +15,47 @@ pub async fn component_login( password: String, ) -> Result, Error> { let password = password; - let mut xmpp_stream = connector.connect(&jid, ns::COMPONENT).await?; - auth(&mut xmpp_stream, password).await?; - Ok(xmpp_stream) + let mut stream = connector.connect(&jid, ns::COMPONENT).await?; + let header = stream.take_header(); + let mut stream = stream.skip_features(); + let stream_id = match header.id { + Some(ref v) => &**v, + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "stream id missing on component stream", + ) + .into()) + } + }; + auth(&mut stream, stream_id, &password).await?; + Ok(stream) } -pub async fn auth( +pub async fn auth( stream: &mut XmppStream, - password: String, + stream_id: &str, + password: &str, ) -> Result<(), Error> { - let nonza = Handshake::from_password_and_stream_id(&password, &stream.id); - stream.send_stanza(nonza).await?; + let nonza = Handshake::from_password_and_stream_id(password, stream_id); + stream + .send(&XmppStreamElement::ComponentHandshake(nonza)) + .await?; loop { match stream.next().await { - Some(Ok(Packet::Stanza(ref stanza))) - if stanza.is("handshake", ns::COMPONENT_ACCEPT) => - { + Some(Ok(XmppStreamElement::ComponentHandshake(_))) => { return Ok(()); } - Some(Ok(Packet::Stanza(ref stanza))) - if stanza.is("error", "http://etherx.jabber.org/streams") => - { + Some(Ok(_)) => { return Err(AuthError::ComponentFail.into()); } - Some(_) => {} - None => return Err(Error::Disconnected), + Some(Err(ReadError::SoftTimeout)) => (), + Some(Err(ReadError::HardError(e))) => return Err(e.into()), + Some(Err(ReadError::ParseError(e))) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e).into()) + } + Some(Err(ReadError::StreamFooterReceived)) | None => return Err(Error::Disconnected), } } } diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index 52391f14b4826778e48262a3656ce57038a414ea..efd58a40cd1e0d773804536357115d4c0d4673e0 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -2,15 +2,12 @@ //! XMPP server under a JID consisting of just a domain name. They are //! allowed to use any user and resource identifiers in their stanzas. use futures::sink::SinkExt; -use minidom::Element; use std::str::FromStr; -use xmpp_parsers::{jid::Jid, ns}; +use xmpp_parsers::jid::Jid; use crate::{ - component::login::component_login, - connect::ServerConnector, - proto::{add_stanza_id, XmppStream}, - Error, + component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error, + Stanza, }; #[cfg(any(feature = "starttls", feature = "insecure-tcp"))] @@ -33,8 +30,9 @@ pub struct Component { impl Component { /// Send stanza - pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> { - self.send(add_stanza_id(stanza, ns::COMPONENT_ACCEPT)).await + pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<(), Error> { + stanza.ensure_id(); + self.send(stanza).await } /// End connection diff --git a/tokio-xmpp/src/component/stream.rs b/tokio-xmpp/src/component/stream.rs index 3ae62c050595ded87aa35bcec48234a835409706..b1d0d4667aa6c4cea41b810785764e412ccdd72e 100644 --- a/tokio-xmpp/src/component/stream.rs +++ b/tokio-xmpp/src/component/stream.rs @@ -2,21 +2,27 @@ //! XMPP server under a JID consisting of just a domain name. They are //! allowed to use any user and resource identifiers in their stanzas. use futures::{task::Poll, Sink, Stream}; -use minidom::Element; use std::pin::Pin; use std::task::Context; -use crate::{component::Component, connect::ServerConnector, proto::Packet, Error}; +use crate::{ + component::Component, connect::ServerConnector, xmlstream::XmppStreamElement, Error, Stanza, +}; impl Stream for Component { - type Item = Element; + type Item = Stanza; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => return Poll::Ready(Some(stanza)), - Poll::Ready(Some(Ok(Packet::Text(_)))) => { - // retry + Poll::Ready(Some(Ok(XmppStreamElement::Iq(stanza)))) => { + return Poll::Ready(Some(Stanza::Iq(stanza))) + } + Poll::Ready(Some(Ok(XmppStreamElement::Message(stanza)))) => { + return Poll::Ready(Some(Stanza::Message(stanza))) + } + Poll::Ready(Some(Ok(XmppStreamElement::Presence(stanza)))) => { + return Poll::Ready(Some(Stanza::Presence(stanza))) } Poll::Ready(Some(Ok(_))) => // unexpected @@ -31,12 +37,12 @@ impl Stream for Component { } } -impl Sink for Component { +impl Sink for Component { type Error = Error; - fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { Pin::new(&mut self.stream) - .start_send(Packet::Stanza(item)) + .start_send(&item.into()) .map_err(|e| e.into()) } diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index 2d7cb5b36eb3766263bbd9761b12f9ca4105ba3c..937b04026914484522244dbd7217fb3bb3470a6c 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -1,10 +1,10 @@ //! `ServerConnector` provides streams for XMPP clients use sasl::common::ChannelBinding; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncBufRead, AsyncWrite}; use xmpp_parsers::jid::Jid; -use crate::proto::XmppStream; +use crate::xmlstream::PendingFeaturesRecv; use crate::Error; #[cfg(feature = "starttls")] @@ -21,8 +21,8 @@ mod dns; pub use dns::DnsConfig; /// trait returned wrapped in XmppStream by ServerConnector -pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {} -impl AsyncReadAndWrite for T {} +pub trait AsyncReadAndWrite: AsyncBufRead + AsyncWrite + Unpin + Send {} +impl AsyncReadAndWrite for T {} /// Trait that must be extended by the implementation of ServerConnector pub trait ServerConnectorError: std::error::Error + Sync + Send {} @@ -35,8 +35,8 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static { fn connect( &self, jid: &Jid, - ns: &str, - ) -> impl std::future::Future, Error>> + Send; + ns: &'static str, + ) -> impl std::future::Future, Error>> + Send; /// Return channel binding data if available /// do not fail if channel binding is simply unavailable, just return Ok(None) diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index 5d8550a74f52c16c5b475a3b016cf827bbb28abf..1d0d26e4f157e711b188d6171e42b72c0bf9e5c3 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -2,8 +2,10 @@ #[cfg(feature = "tls-native")] use native_tls::Error as TlsError; +use std::borrow::Cow; use std::error::Error as StdError; use std::fmt; +use std::io; #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] use tokio_rustls::rustls::pki_types::InvalidDnsNameError; #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] @@ -28,18 +30,23 @@ use { tokio_native_tls::{TlsConnector, TlsStream}, }; -use minidom::Element; use sasl::common::ChannelBinding; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite, BufStream}, net::TcpStream, }; -use xmpp_parsers::{jid::Jid, ns}; +use xmpp_parsers::{ + jid::Jid, + starttls::{self, Request}, +}; use crate::{ connect::{DnsConfig, ServerConnector, ServerConnectorError}, error::{Error, ProtocolError}, - proto::{Packet, XmppStream}, + xmlstream::{ + initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream, + XmppStreamElement, + }, Client, }; @@ -57,20 +64,44 @@ impl From for StartTlsServerConnector { } impl ServerConnector for StartTlsServerConnector { - type Stream = TlsStream; - async fn connect(&self, jid: &Jid, ns: &str) -> Result, Error> { - let tcp_stream = self.0.resolve().await?; + type Stream = BufStream>; - // Unencryped XmppStream - let xmpp_stream = XmppStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?; + async fn connect( + &self, + jid: &Jid, + ns: &'static str, + ) -> Result, Error> { + let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?); - if xmpp_stream.stream_features.can_starttls() { + // Unencryped XmppStream + let xmpp_stream = initiate_stream( + tcp_stream, + ns, + StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }, + ) + .await?; + let (features, xmpp_stream) = xmpp_stream.recv_features().await?; + + if features.can_starttls() { // TlsStream - let tls_stream = starttls(xmpp_stream).await?; + let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?; // Encrypted XmppStream - Ok(XmppStream::start(tls_stream, jid.clone(), ns.to_owned()).await?) + Ok(initiate_stream( + tokio::io::BufStream::new(tls_stream), + ns, + StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }, + ) + .await?) } else { - return Err(crate::Error::Protocol(ProtocolError::NoTls).into()); + Err(crate::Error::Protocol(ProtocolError::NoTls).into()) } } @@ -84,7 +115,7 @@ impl ServerConnector for StartTlsServerConnector { } #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] { - let (_, connection) = stream.get_ref(); + let (_, connection) = stream.get_ref().get_ref(); Ok(match connection.protocol_version() { // TODO: Add support for TLS 1.2 and earlier. Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { @@ -102,10 +133,11 @@ impl ServerConnector for StartTlsServerConnector { #[cfg(feature = "tls-native")] async fn get_tls_stream( - xmpp_stream: XmppStream, + xmpp_stream: XmppStream>, + domain: &str, ) -> Result, Error> { - let domain = xmpp_stream.jid.domain().to_owned(); - let stream = xmpp_stream.into_inner(); + let domain = domain.to_owned(); + let stream = xmpp_stream.into_inner().into_inner(); let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) .connect(&domain, stream) .await @@ -115,11 +147,11 @@ async fn get_tls_stream( #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] async fn get_tls_stream( - xmpp_stream: XmppStream, + xmpp_stream: XmppStream>, + domain: &str, ) -> Result, Error> { - let domain = xmpp_stream.jid.domain().to_string(); - let domain = ServerName::try_from(domain).map_err(|e| StartTlsError::DnsNameError(e))?; - let stream = xmpp_stream.into_inner(); + let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?; + let stream = xmpp_stream.into_inner().into_inner(); let mut root_store = RootCertStore::empty(); #[cfg(feature = "webpki-roots")] { @@ -142,24 +174,33 @@ async fn get_tls_stream( /// Performs `` on an XmppStream and returns a binary /// TlsStream. pub async fn starttls( - mut xmpp_stream: XmppStream, + mut stream: XmppStream>, + domain: &str, ) -> Result, Error> { - let nonza = Element::builder("starttls", ns::TLS).build(); - let packet = Packet::Stanza(nonza); - xmpp_stream.send(packet).await?; + stream + .send(&XmppStreamElement::Starttls(starttls::Nonza::Request( + Request, + ))) + .await?; loop { - match xmpp_stream.next().await { - Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break, - Some(Ok(Packet::Text(_))) => {} - Some(Err(e)) => return Err(e.into()), - _ => { - return Err(crate::Error::Protocol(ProtocolError::NoTls).into()); + match stream.next().await { + Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => { + break; + } + Some(Ok(_)) => (), + Some(Err(ReadError::SoftTimeout)) => (), + Some(Err(ReadError::HardError(e))) => return Err(e.into()), + Some(Err(ReadError::ParseError(e))) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e).into()) + } + None | Some(Err(ReadError::StreamFooterReceived)) => { + return Err(crate::Error::Disconnected) } } } - get_tls_stream(xmpp_stream).await + get_tls_stream(stream, domain).await } /// StartTLS ServerConnector Error diff --git a/tokio-xmpp/src/connect/tcp.rs b/tokio-xmpp/src/connect/tcp.rs index 21061cf6df8cb881d5fbe924b30ddec44fb8777c..89a9e12dd02528d37a395e24271ebb100b985caa 100644 --- a/tokio-xmpp/src/connect/tcp.rs +++ b/tokio-xmpp/src/connect/tcp.rs @@ -1,10 +1,12 @@ //! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections -use tokio::net::TcpStream; +use std::borrow::Cow; + +use tokio::{io::BufStream, net::TcpStream}; use crate::{ connect::{DnsConfig, ServerConnector}, - proto::XmppStream, + xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader}, Client, Component, Error, }; @@ -27,14 +29,23 @@ impl From for TcpServerConnector { } impl ServerConnector for TcpServerConnector { - type Stream = TcpStream; + type Stream = BufStream; async fn connect( &self, jid: &xmpp_parsers::jid::Jid, - ns: &str, - ) -> Result, Error> { - let stream = self.0.resolve().await?; - Ok(XmppStream::start(stream, jid.clone(), ns.to_owned()).await?) + ns: &'static str, + ) -> Result, Error> { + let stream = BufStream::new(self.0.resolve().await?); + Ok(initiate_stream( + stream, + ns, + StreamHeader { + to: Some(Cow::Borrowed(jid.domain().as_str())), + from: None, + id: None, + }, + ) + .await?) } } diff --git a/tokio-xmpp/src/event.rs b/tokio-xmpp/src/event.rs index 05eaf9e69416e0d1339f29c549b3c4c9a7238d23..9c69204f98d5fd831c0e497690ea95ed2e581b7e 100644 --- a/tokio-xmpp/src/event.rs +++ b/tokio-xmpp/src/event.rs @@ -1,6 +1,103 @@ -use super::Error; -use minidom::Element; -use xmpp_parsers::jid::Jid; +use rand::{thread_rng, Rng}; +use xmpp_parsers::{iq::Iq, jid::Jid, message::Message, presence::Presence}; + +use crate::xmlstream::XmppStreamElement; +use crate::Error; + +fn make_id() -> String { + let id: u64 = thread_rng().gen(); + format!("{}", id) +} + +/// A stanza sent/received over the stream. +#[derive(Debug)] +pub enum Stanza { + /// IQ stanza + Iq(Iq), + + /// Message stanza + Message(Message), + + /// Presence stanza + Presence(Presence), +} + +impl Stanza { + /// Assign a random ID to the stanza, if no ID has been assigned yet. + pub fn ensure_id(&mut self) -> &str { + match self { + Self::Iq(iq) => { + if iq.id.len() == 0 { + iq.id = make_id(); + } + &iq.id + } + Self::Message(message) => message.id.get_or_insert_with(make_id), + Self::Presence(presence) => presence.id.get_or_insert_with(make_id), + } + } +} + +impl From for Stanza { + fn from(other: Iq) -> Self { + Self::Iq(other) + } +} + +impl From for Stanza { + fn from(other: Presence) -> Self { + Self::Presence(other) + } +} + +impl From for Stanza { + fn from(other: Message) -> Self { + Self::Message(other) + } +} + +impl TryFrom for Message { + type Error = Stanza; + + fn try_from(other: Stanza) -> Result { + match other { + Stanza::Message(st) => Ok(st), + other => Err(other), + } + } +} + +impl TryFrom for Presence { + type Error = Stanza; + + fn try_from(other: Stanza) -> Result { + match other { + Stanza::Presence(st) => Ok(st), + other => Err(other), + } + } +} + +impl TryFrom for Iq { + type Error = Stanza; + + fn try_from(other: Stanza) -> Result { + match other { + Stanza::Iq(st) => Ok(st), + other => Err(other), + } + } +} + +impl From for XmppStreamElement { + fn from(other: Stanza) -> Self { + match other { + Stanza::Iq(st) => Self::Iq(st), + Stanza::Message(st) => Self::Message(st), + Stanza::Presence(st) => Self::Presence(st), + } + } +} /// High-level event on the Stream implemented by Client and Component #[derive(Debug)] @@ -21,7 +118,7 @@ pub enum Event { /// Stream end Disconnected(Error), /// Received stanza/nonza - Stanza(Element), + Stanza(Stanza), } impl Event { @@ -41,16 +138,8 @@ impl Event { } } - /// `Stanza` event? - pub fn is_stanza(&self, name: &str) -> bool { - match *self { - Event::Stanza(ref stanza) => stanza.name() == name, - _ => false, - } - } - /// If this is a `Stanza` event, get its data - pub fn as_stanza(&self) -> Option<&Element> { + pub fn as_stanza(&self) -> Option<&Stanza> { match *self { Event::Stanza(ref stanza) => Some(stanza), _ => None, @@ -58,7 +147,7 @@ impl Event { } /// If this is a `Stanza` event, unwrap into its data - pub fn into_stanza(self) -> Option { + pub fn into_stanza(self) -> Option { match self { Event::Stanza(stanza) => Some(stanza), _ => None, diff --git a/tokio-xmpp/src/lib.rs b/tokio-xmpp/src/lib.rs index a10cf5f040a566fbec21b6856465198a1869578d..5b00e42f4ffcb05435b23196f5180a8fe03632bd 100644 --- a/tokio-xmpp/src/lib.rs +++ b/tokio-xmpp/src/lib.rs @@ -47,9 +47,8 @@ compile_error!( ); mod event; -pub use event::Event; +pub use event::{Event, Stanza}; pub mod connect; -pub mod proto; pub mod xmlstream; mod client; diff --git a/tokio-xmpp/src/proto/mod.rs b/tokio-xmpp/src/proto/mod.rs deleted file mode 100644 index 34de560a3fa423cc84c9da623c7751b9bac5a5a4..0000000000000000000000000000000000000000 --- a/tokio-xmpp/src/proto/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! Low-level stream establishment - -mod xmpp_codec; -mod xmpp_stream; - -pub use xmpp_codec::{Packet, XmppCodec}; -pub(crate) use xmpp_stream::add_stanza_id; -pub use xmpp_stream::XmppStream; diff --git a/tokio-xmpp/src/proto/xmpp_codec.rs b/tokio-xmpp/src/proto/xmpp_codec.rs deleted file mode 100644 index db028e7b815b47d703a8dd5af0ae82c801f26d0b..0000000000000000000000000000000000000000 --- a/tokio-xmpp/src/proto/xmpp_codec.rs +++ /dev/null @@ -1,410 +0,0 @@ -//! XML stream parser for XMPP - -use crate::Error; -use bytes::{BufMut, BytesMut}; -use log::debug; -use minidom::tree_builder::TreeBuilder; -use minidom::Element; -use rxml::{Parse, RawParser}; -use std::collections::HashMap; -use std::fmt::Write; -use std::io; -#[cfg(feature = "syntax-highlighting")] -use std::sync::OnceLock; -use tokio_util::codec::{Decoder, Encoder}; - -#[cfg(feature = "syntax-highlighting")] -static PS: OnceLock = OnceLock::new(); -#[cfg(feature = "syntax-highlighting")] -static SYNTAX: OnceLock = OnceLock::new(); -#[cfg(feature = "syntax-highlighting")] -static THEME: OnceLock = OnceLock::new(); - -#[cfg(feature = "syntax-highlighting")] -fn init_syntect() { - let ps = syntect::parsing::SyntaxSet::load_defaults_newlines(); - let syntax = ps.find_syntax_by_extension("xml").unwrap(); - let ts = syntect::highlighting::ThemeSet::load_defaults(); - let theme = ts.themes["Solarized (dark)"].clone(); - - SYNTAX.set(syntax.clone()).unwrap(); - PS.set(ps).unwrap(); - THEME.set(theme).unwrap(); -} - -#[cfg(feature = "syntax-highlighting")] -fn highlight_xml(xml: &str) -> String { - let mut h = syntect::easy::HighlightLines::new(SYNTAX.get().unwrap(), THEME.get().unwrap()); - let ranges: Vec<_> = h.highlight_line(&xml, PS.get().unwrap()).unwrap(); - let escaped = syntect::util::as_24_bit_terminal_escaped(&ranges[..], false); - format!("{}\x1b[0m", escaped) -} - -#[cfg(not(feature = "syntax-highlighting"))] -fn highlight_xml(xml: &str) -> &str { - xml -} - -/// Anything that can be sent or received on an XMPP/XML stream -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Packet { - /// `` start tag - StreamStart(HashMap), - /// A complete stanza or nonza - Stanza(Element), - /// Plain text (think whitespace keep-alive) - Text(String), - /// `` closing tag - StreamEnd, -} - -/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet` -pub struct XmppCodec { - /// Outgoing - ns: Option, - /// Incoming - driver: RawParser, - stanza_builder: TreeBuilder, -} - -impl XmppCodec { - /// Constructor - pub fn new() -> Self { - let stanza_builder = TreeBuilder::new(); - let driver = RawParser::new(); - #[cfg(feature = "syntax-highlighting")] - if log::log_enabled!(log::Level::Debug) && PS.get().is_none() { - init_syntect(); - } - XmppCodec { - ns: None, - driver, - stanza_builder, - } - } -} - -impl Default for XmppCodec { - fn default() -> Self { - Self::new() - } -} - -impl Decoder for XmppCodec { - type Item = Packet; - type Error = Error; - - fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - loop { - let token = match self.driver.parse_buf(buf, false) { - Ok(Some(token)) => token, - Ok(None) => break, - Err(rxml::error::EndOrError::NeedMoreData) => break, - Err(rxml::error::EndOrError::Error(e)) => { - return Err(minidom::Error::from(e).into()) - } - }; - - let had_stream_root = self.stanza_builder.depth() > 0; - self.stanza_builder.process_event(token)?; - let has_stream_root = self.stanza_builder.depth() > 0; - - if !had_stream_root && has_stream_root { - let root = self.stanza_builder.top().unwrap(); - let attrs = - root.attrs() - .map(|(name, value)| (name.to_owned(), value.to_owned())) - .chain(root.prefixes.declared_prefixes().iter().map( - |(prefix, namespace)| { - ( - prefix - .as_ref() - .map(|prefix| format!("xmlns:{}", prefix)) - .unwrap_or_else(|| "xmlns".to_owned()), - namespace.clone(), - ) - }, - )) - .collect(); - debug!("<< {}", highlight_xml(&String::from(root))); - return Ok(Some(Packet::StreamStart(attrs))); - } else if self.stanza_builder.depth() == 1 { - self.driver.release_temporaries(); - - if let Some(stanza) = self.stanza_builder.unshift_child() { - debug!("<< {}", highlight_xml(&String::from(&stanza))); - return Ok(Some(Packet::Stanza(stanza))); - } - } else if let Some(_) = self.stanza_builder.root.take() { - self.driver.release_temporaries(); - - debug!("<< {}", highlight_xml("")); - return Ok(Some(Packet::StreamEnd)); - } - } - - Ok(None) - } - - fn decode_eof(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { - self.decode(buf) - } -} - -impl Encoder for XmppCodec { - type Error = Error; - - fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> { - let remaining = dst.capacity() - dst.len(); - let max_stanza_size: usize = 2usize.pow(16); - if remaining < max_stanza_size { - dst.reserve(max_stanza_size - remaining); - } - - fn to_io_err>>(e: E) -> io::Error { - io::Error::new(io::ErrorKind::InvalidInput, e) - } - - match item { - Packet::StreamStart(start_attrs) => { - let mut buf = String::new(); - write!(buf, "").map_err(to_io_err)?; - - write!(dst, "{}", buf)?; - let utf8 = std::str::from_utf8(dst)?; - debug!(">> {}", highlight_xml(utf8)) - } - Packet::Stanza(stanza) => { - let _ = stanza - .write_to(&mut WriteBytes::new(dst)) - .map_err(|e| to_io_err(format!("{}", e)))?; - let utf8 = std::str::from_utf8(dst)?; - debug!(">> {}", highlight_xml(utf8)); - } - Packet::Text(text) => { - let _ = write_text(&text, dst).map_err(to_io_err)?; - let utf8 = std::str::from_utf8(dst)?; - debug!(">> {}", highlight_xml(utf8)); - } - Packet::StreamEnd => { - let _ = write!(dst, "\n").map_err(to_io_err); - debug!(">> {}", highlight_xml("")); - } - } - - Ok(()) - } -} - -/// Write XML-escaped text string -pub fn write_text(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> { - write!(writer, "{}", escape(text)) -} - -/// Copied from `RustyXML` for now -pub fn escape(input: &str) -> String { - let mut result = String::with_capacity(input.len()); - - for c in input.chars() { - match c { - '&' => result.push_str("&"), - '<' => result.push_str("<"), - '>' => result.push_str(">"), - '\'' => result.push_str("'"), - '"' => result.push_str("""), - o => result.push(o), - } - } - result -} - -/// BytesMut impl only std::fmt::Write but not std::io::Write. The -/// latter trait is required for minidom's -/// `Element::write_to_inner()`. -struct WriteBytes<'a> { - dst: &'a mut BytesMut, -} - -impl<'a> WriteBytes<'a> { - fn new(dst: &'a mut BytesMut) -> Self { - WriteBytes { dst } - } -} - -impl<'a> std::io::Write for WriteBytes<'a> { - fn write(&mut self, buf: &[u8]) -> std::result::Result { - self.dst.put_slice(buf); - Ok(buf.len()) - } - - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_stream_start() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - } - - #[test] - fn test_stream_end() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamEnd)) => true, - _ => false, - }); - } - - #[test] - fn test_truncated_stanza() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - - b.put_slice("ß true, - _ => false, - }); - - b.put_slice(b">"); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true, - _ => false, - }); - } - - #[test] - fn test_truncated_utf8() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - - b.put(&b"\xc3"[..]); - let r = c.decode(&mut b); - assert!(match r { - Ok(None) => true, - _ => false, - }); - - b.put(&b"\x9f"[..]); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true, - _ => false, - }); - } - - /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3 - #[test] - fn test_atrribute_prefix() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - - b.put_slice(b"Test status"); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::Stanza(ref el))) - if el.name() == "status" - && el.text() == "Test status" - && el.attr("xml:lang").map_or(false, |a| a == "en") => - true, - _ => false, - }); - } - - /// By default, encode() only gets a BytesMut that has 8 KiB space reserved. - #[test] - fn test_large_stanza() { - use futures::{executor::block_on, sink::SinkExt}; - use std::io::Cursor; - use tokio_util::codec::FramedWrite; - let mut framed = FramedWrite::new(Cursor::new(vec![]), XmppCodec::new()); - let mut text = "".to_owned(); - for _ in 0..2usize.pow(15) { - text = text + "A"; - } - let stanza = Element::builder("message", "jabber:client") - .append( - Element::builder("body", "jabber:client") - .append(text.as_ref()) - .build(), - ) - .build(); - block_on(framed.send(Packet::Stanza(stanza))).expect("send"); - assert_eq!( - framed.get_ref().get_ref(), - &format!( - "{}", - text - ) - .as_bytes() - ); - } - - #[test] - fn test_cut_out_stanza() { - let mut c = XmppCodec::new(); - let mut b = BytesMut::with_capacity(1024); - b.put_slice(b""); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::StreamStart(_))) => true, - _ => false, - }); - - b.put_slice(b"Foo"); - let r = c.decode(&mut b); - assert!(match r { - Ok(Some(Packet::Stanza(_))) => true, - _ => false, - }); - } -} diff --git a/tokio-xmpp/src/proto/xmpp_stream.rs b/tokio-xmpp/src/proto/xmpp_stream.rs deleted file mode 100644 index 2624898d45ade8bef2ac64a4faa7d817204710af..0000000000000000000000000000000000000000 --- a/tokio-xmpp/src/proto/xmpp_stream.rs +++ /dev/null @@ -1,193 +0,0 @@ -//! `XmppStream` provides encoding/decoding for XMPP - -use futures::{ - sink::{Send, SinkExt}, - stream::StreamExt, - task::Poll, - Sink, Stream, -}; -use minidom::Element; -use rand::{thread_rng, Rng}; -use std::pin::Pin; -use std::task::Context; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; -use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures, Error as ParsersError}; - -use crate::error::{Error, ProtocolError}; -use crate::proto::{Packet, XmppCodec}; - -fn make_id() -> String { - let id: u64 = thread_rng().gen(); - format!("{}", id) -} - -pub(crate) fn add_stanza_id(mut stanza: Element, default_ns: &str) -> Element { - if stanza.is("iq", default_ns) - || stanza.is("message", default_ns) - || stanza.is("presence", default_ns) - { - if stanza.attr("id").is_none() { - stanza.set_attr("id", make_id()); - } - } - - stanza -} - -/// Wraps a binary stream (tokio's `AsyncRead + AsyncWrite`) to decode -/// and encode XMPP packets. -/// -/// Implements `Sink + Stream` -pub struct XmppStream { - /// The local Jabber-Id - pub jid: Jid, - /// Codec instance - pub stream: Framed, - /// `` for XMPP version 1.0 - pub stream_features: StreamFeatures, - /// Root namespace - /// - /// This is different for either c2s, s2s, or component - /// connections. - pub ns: String, - /// Stream `id` attribute - pub id: String, -} - -impl XmppStream { - /// Constructor - pub fn new( - jid: Jid, - stream: Framed, - ns: String, - id: String, - stream_features: StreamFeatures, - ) -> Self { - XmppStream { - jid, - stream, - stream_features, - ns, - id, - } - } - - /// Send a `` start tag - pub async fn start(stream: S, jid: Jid, ns: String) -> Result { - let mut stream = Framed::new(stream, XmppCodec::new()); - let attrs = [ - ("to".to_owned(), jid.domain().to_string()), - ("version".to_owned(), "1.0".to_owned()), - ("xmlns".to_owned(), ns.clone()), - ("xmlns:stream".to_owned(), ns::STREAM.to_owned()), - ] - .iter() - .cloned() - .collect(); - stream.send(Packet::StreamStart(attrs)).await?; - - let stream_attrs; - loop { - match stream.next().await { - Some(Ok(Packet::StreamStart(attrs))) => { - stream_attrs = attrs; - break; - } - Some(Ok(_)) => {} - Some(Err(e)) => return Err(e.into()), - None => return Err(Error::Disconnected), - } - } - - let stream_ns = stream_attrs - .get("xmlns") - .ok_or(ProtocolError::NoStreamNamespace)? - .clone(); - let stream_id = stream_attrs - .get("id") - .ok_or(ProtocolError::NoStreamId)? - .clone(); - if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() { - loop { - match stream.next().await { - Some(Ok(Packet::Stanza(stanza))) => { - let stream_features = StreamFeatures::try_from(stanza) - .map_err(|e| Error::Protocol(ParsersError::from(e).into()))?; - return Ok(XmppStream::new(jid, stream, ns, stream_id, stream_features)); - } - Some(Ok(_)) => {} - Some(Err(e)) => return Err(e.into()), - None => return Err(Error::Disconnected), - } - } - } else { - // FIXME: huge hack, shouldn’t be an element! - return Ok(XmppStream::new( - jid, - stream, - ns, - stream_id.clone(), - StreamFeatures::default(), - )); - } - } - - /// Unwraps the inner stream - pub fn into_inner(self) -> S { - self.stream.into_inner() - } - - /// Re-run `start()` - pub async fn restart(self) -> Result { - let stream = self.stream.into_inner(); - Self::start(stream, self.jid, self.ns).await - } -} - -impl XmppStream { - /// Convenience method - pub fn send_stanza>(&mut self, e: E) -> Send { - self.send(Packet::Stanza(e.into())) - } -} - -/// Proxy to self.stream -impl Sink for XmppStream { - type Error = crate::Error; - - fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll> { - // Pin::new(&mut self.stream).poll_ready(ctx) - // .map_err(|e| e.into()) - Poll::Ready(Ok(())) - } - - fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { - Pin::new(&mut self.stream) - .start_send(item) - .map_err(|e| e.into()) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.stream) - .poll_flush(cx) - .map_err(|e| e.into()) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.stream) - .poll_close(cx) - .map_err(|e| e.into()) - } -} - -/// Proxy to self.stream -impl Stream for XmppStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.stream) - .poll_next(cx) - .map(|result| result.map(|result| result.map_err(|e| e.into()))) - } -} diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index c560ab375195831597795db91bf5f2b053e289ec..9d0d6e8e9d1a445e9fe7132a7ca39e3bbfc9fd49 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -127,6 +127,10 @@ impl RawXmlStream { *this.parser.parser_pinned() = rxml::Parser::default(); *this.writer = Self::new_writer(this.stream_ns); } + + pub(super) fn into_inner(self) -> Io { + self.parser.into_inner().0 + } } impl RawXmlStream { diff --git a/tokio-xmpp/src/xmlstream/initiator.rs b/tokio-xmpp/src/xmlstream/initiator.rs index 5d6528e04ee19ee13ba13ed9478c9f096685d845..0829c9a0f62679c4c3ddee49683e7db33627851e 100644 --- a/tokio-xmpp/src/xmlstream/initiator.rs +++ b/tokio-xmpp/src/xmlstream/initiator.rs @@ -83,4 +83,19 @@ impl PendingFeaturesRecv { let features = ReadXso::read_from(Pin::new(&mut stream)).await?; Ok((features, XmlStream::wrap(stream))) } + + /// Skip receiving the responder's stream features. + /// + /// The stream can be used for exchanging stream-level elements (stanzas + /// or "nonzas"). The Rust type for these elements must be given as type + /// parameter `T`. + /// + /// **Note:** Using this on RFC 6120 compliant streams where stream + /// features **are** sent after the stream header will cause a parse error + /// down the road (because the feature stream element cannot be handled). + /// The only place where this is useful is in + /// [XEP-0114](https://xmpp.org/extensions/xep-0114.html) connections. + pub fn skip_features(self) -> XmlStream { + XmlStream::wrap(self.stream) + } } diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index 66f4d01c4b57d82e47301f3a02120984bb9f9abe..799c6306681ebe165b5622405a7d5b791776b668 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -54,7 +54,8 @@ mod responder; mod tests; pub(crate) mod xmpp; -use self::common::{RawXmlStream, ReadXsoError, ReadXsoState, StreamHeader}; +pub use self::common::StreamHeader; +use self::common::{RawXmlStream, ReadXsoError, ReadXsoState}; pub use self::initiator::{InitiatingStream, PendingFeaturesRecv}; pub use self::responder::{AcceptedStream, PendingFeaturesSend}; pub use self::xmpp::XmppStreamElement; @@ -235,6 +236,12 @@ impl XmlStream let header = StreamHeader::recv(Pin::new(&mut stream)).await?; Ok(AcceptedStream { stream, header }) } + + /// Discard all XML state and return the inner I/O object. + pub fn into_inner(self) -> Io { + self.assert_retypable(); + self.inner.into_inner() + } } impl Stream for XmlStream { diff --git a/tokio-xmpp/src/xmlstream/xmpp.rs b/tokio-xmpp/src/xmlstream/xmpp.rs index 8e39aab465db5b4d07e559073693c13dca727326..9cee7e473a16347f8736f171d5cd3c3400db8e0d 100644 --- a/tokio-xmpp/src/xmlstream/xmpp.rs +++ b/tokio-xmpp/src/xmlstream/xmpp.rs @@ -6,7 +6,7 @@ use xso::{AsXml, FromXml}; -use xmpp_parsers::{iq::Iq, message::Message, presence::Presence, sasl, starttls}; +use xmpp_parsers::{component, iq::Iq, message::Message, presence::Presence, sasl, starttls}; /// Any valid XMPP stream-level element. #[derive(FromXml, AsXml, Debug)] @@ -31,4 +31,8 @@ pub enum XmppStreamElement { /// STARTTLS-related nonza #[xml(transparent)] Starttls(starttls::Nonza), + + /// Component protocol nonzas + #[xml(transparent)] + ComponentHandshake(component::Handshake), } diff --git a/xmpp/src/event_loop.rs b/xmpp/src/event_loop.rs index 9f0f470086d8c543c41dd94216cd25fdaed8ef32..79c29305fc6ac54b2aca7ffa916ac476525197ff 100644 --- a/xmpp/src/event_loop.rs +++ b/xmpp/src/event_loop.rs @@ -7,10 +7,8 @@ use futures::StreamExt; use tokio_xmpp::connect::ServerConnector; use tokio_xmpp::{ - parsers::{ - disco::DiscoInfoQuery, iq::Iq, message::Message, presence::Presence, roster::Roster, - }, - Event as TokioXmppEvent, + parsers::{disco::DiscoInfoQuery, iq::Iq, roster::Roster}, + Event as TokioXmppEvent, Stanza, }; use crate::{iq, message, presence, Agent, Event}; @@ -46,24 +44,17 @@ pub async fn wait_for_events(agent: &mut Agent) -> Vec { events.push(Event::Disconnected(e)); } - TokioXmppEvent::Stanza(elem) => { - if elem.is("iq", "jabber:client") { - let iq = Iq::try_from(elem).unwrap(); - let new_events = iq::handle_iq(agent, iq).await; - events.extend(new_events); - } else if elem.is("message", "jabber:client") { - let message = Message::try_from(elem).unwrap(); - let new_events = message::receive::handle_message(agent, message).await; - events.extend(new_events); - } else if elem.is("presence", "jabber:client") { - let presence = Presence::try_from(elem).unwrap(); - let new_events = presence::receive::handle_presence(agent, presence).await; - events.extend(new_events); - } else if elem.is("error", "http://etherx.jabber.org/streams") { - println!("Received a fatal stream error: {}", String::from(&elem)); - } else { - panic!("Unknown stanza: {}", String::from(&elem)); - } + TokioXmppEvent::Stanza(Stanza::Iq(iq)) => { + let new_events = iq::handle_iq(agent, iq).await; + events.extend(new_events); + } + TokioXmppEvent::Stanza(Stanza::Message(message)) => { + let new_events = message::receive::handle_message(agent, message).await; + events.extend(new_events); + } + TokioXmppEvent::Stanza(Stanza::Presence(presence)) => { + let new_events = presence::receive::handle_presence(agent, presence).await; + events.extend(new_events); } }