Detailed changes
@@ -17,7 +17,6 @@ futures = "0.3"
log = "0.4"
tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] }
tokio-stream = { version = "0.1", features = [] }
-tokio-util = { version = "0.7", features = ["codec"] }
webpki-roots = { version = "0.26", optional = true }
rustls-native-certs = { version = "0.7", optional = true }
rxml = { version = "0.12.0", features = ["compact_str"] }
@@ -1,9 +1,8 @@
use futures::stream::StreamExt;
-use minidom::Element;
use std::env::args;
use std::process::exit;
use std::str::FromStr;
-use tokio_xmpp::Client;
+use tokio_xmpp::{Client, Stanza};
use xmpp_parsers::{
disco::{DiscoInfoQuery, DiscoInfoResult},
iq::{Iq, IqType},
@@ -41,24 +40,21 @@ async fn main() {
let target_jid: Jid = target.clone().parse().unwrap();
let iq = make_disco_iq(target_jid);
println!("Sending disco#info request to {}", target.clone());
- println!(">> {}", String::from(&iq));
- client.send_stanza(iq).await.unwrap();
- } else if let Some(stanza) = event.into_stanza() {
- if stanza.is("iq", "jabber:client") {
- let iq = Iq::try_from(stanza).unwrap();
- if let IqType::Result(Some(payload)) = iq.payload {
- if payload.is("query", ns::DISCO_INFO) {
- if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
- for ext in disco_info.extensions {
- if let Ok(server_info) = ServerInfo::try_from(ext) {
- print_server_info(server_info);
- }
+ println!(">> {:?}", iq);
+ client.send_stanza(iq.into()).await.unwrap();
+ } else if let Some(Stanza::Iq(iq)) = event.into_stanza() {
+ if let IqType::Result(Some(payload)) = iq.payload {
+ if payload.is("query", ns::DISCO_INFO) {
+ if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
+ for ext in disco_info.extensions {
+ if let Ok(server_info) = ServerInfo::try_from(ext) {
+ print_server_info(server_info);
}
}
}
- wait_for_stream_end = true;
- client.send_end().await.unwrap();
}
+ wait_for_stream_end = true;
+ client.send_end().await.unwrap();
}
}
} else {
@@ -67,11 +63,10 @@ async fn main() {
}
}
-fn make_disco_iq(target: Jid) -> Element {
+fn make_disco_iq(target: Jid) -> Iq {
Iq::from_get("disco", DiscoInfoQuery { node: None })
.with_id(String::from("contact"))
.with_to(target)
- .into()
}
fn convert_field(field: Vec<String>) -> String {
@@ -1,11 +1,10 @@
use futures::stream::StreamExt;
-use minidom::Element;
use std::env::args;
use std::fs::{create_dir_all, File};
use std::io::{self, Write};
use std::process::exit;
use std::str::FromStr;
-use tokio_xmpp::Client;
+use tokio_xmpp::{Client, Stanza};
use xmpp_parsers::{
avatar::{Data as AvatarData, Metadata as AvatarMetadata},
caps::{compute_disco, hash_caps, Caps},
@@ -13,7 +12,6 @@ use xmpp_parsers::{
hashes::Algo,
iq::{Iq, IqType},
jid::{BareJid, Jid},
- message::Message,
ns,
presence::{Presence, Type as PresenceType},
pubsub::{
@@ -55,100 +53,107 @@ async fn main() {
let presence = make_presence(caps);
client.send_stanza(presence.into()).await.unwrap();
} else if let Some(stanza) = event.into_stanza() {
- if stanza.is("iq", "jabber:client") {
- let iq = Iq::try_from(stanza).unwrap();
- if let IqType::Get(payload) = iq.payload {
- if payload.is("query", ns::DISCO_INFO) {
- let query = DiscoInfoQuery::try_from(payload);
- match query {
- Ok(query) => {
- let mut disco = disco_info.clone();
- disco.node = query.node;
- let iq = Iq::from_result(iq.id, Some(disco))
- .with_to(iq.from.unwrap());
- client.send_stanza(iq.into()).await.unwrap();
+ match stanza {
+ Stanza::Iq(iq) => {
+ if let IqType::Get(payload) = iq.payload {
+ if payload.is("query", ns::DISCO_INFO) {
+ let query = DiscoInfoQuery::try_from(payload);
+ match query {
+ Ok(query) => {
+ let mut disco = disco_info.clone();
+ disco.node = query.node;
+ let iq = Iq::from_result(iq.id, Some(disco))
+ .with_to(iq.from.unwrap());
+ client.send_stanza(iq.into()).await.unwrap();
+ }
+ Err(err) => client
+ .send_stanza(
+ make_error(
+ iq.from.unwrap(),
+ iq.id,
+ ErrorType::Modify,
+ DefinedCondition::BadRequest,
+ &format!("{}", err),
+ )
+ .into(),
+ )
+ .await
+ .unwrap(),
}
- Err(err) => client
- .send_stanza(make_error(
- iq.from.unwrap(),
- iq.id,
- ErrorType::Modify,
- DefinedCondition::BadRequest,
- &format!("{}", err),
- ))
+ } else {
+ // We MUST answer unhandled get iqs with a service-unavailable error.
+ client
+ .send_stanza(
+ make_error(
+ iq.from.unwrap(),
+ iq.id,
+ ErrorType::Cancel,
+ DefinedCondition::ServiceUnavailable,
+ "No handler defined for this kind of iq.",
+ )
+ .into(),
+ )
.await
- .unwrap(),
+ .unwrap();
}
- } else {
- // We MUST answer unhandled get iqs with a service-unavailable error.
+ } else if let IqType::Result(Some(payload)) = iq.payload {
+ if payload.is("pubsub", ns::PUBSUB) {
+ let pubsub = PubSub::try_from(payload).unwrap();
+ let from = iq.from.clone().unwrap_or(jid.clone().into());
+ handle_iq_result(pubsub, &from);
+ }
+ } else if let IqType::Set(_) = iq.payload {
+ // We MUST answer unhandled set iqs with a service-unavailable error.
client
- .send_stanza(make_error(
- iq.from.unwrap(),
- iq.id,
- ErrorType::Cancel,
- DefinedCondition::ServiceUnavailable,
- "No handler defined for this kind of iq.",
- ))
+ .send_stanza(
+ make_error(
+ iq.from.unwrap(),
+ iq.id,
+ ErrorType::Cancel,
+ DefinedCondition::ServiceUnavailable,
+ "No handler defined for this kind of iq.",
+ )
+ .into(),
+ )
.await
.unwrap();
}
- } else if let IqType::Result(Some(payload)) = iq.payload {
- if payload.is("pubsub", ns::PUBSUB) {
- let pubsub = PubSub::try_from(payload).unwrap();
- let from = iq.from.clone().unwrap_or(jid.clone().into());
- handle_iq_result(pubsub, &from);
- }
- } else if let IqType::Set(_) = iq.payload {
- // We MUST answer unhandled set iqs with a service-unavailable error.
- client
- .send_stanza(make_error(
- iq.from.unwrap(),
- iq.id,
- ErrorType::Cancel,
- DefinedCondition::ServiceUnavailable,
- "No handler defined for this kind of iq.",
- ))
- .await
- .unwrap();
}
- } else if stanza.is("message", "jabber:client") {
- let message = Message::try_from(stanza).unwrap();
- let from = message.from.clone().unwrap();
- if let Some(body) = message.get_best_body(vec!["en"]) {
- if body.0 == "die" {
- println!("Secret die command triggered by {}", from);
- wait_for_stream_end = true;
- client.send_end().await.unwrap();
+ Stanza::Message(message) => {
+ let from = message.from.clone().unwrap();
+ if let Some(body) = message.get_best_body(vec!["en"]) {
+ if body.0 == "die" {
+ println!("Secret die command triggered by {}", from);
+ wait_for_stream_end = true;
+ client.send_end().await.unwrap();
+ }
}
- }
- for child in message.payloads {
- if child.is("event", ns::PUBSUB_EVENT) {
- let event = PubSubEvent::try_from(child).unwrap();
- if let PubSubEvent::PublishedItems { node, items } = event {
- if node.0 == ns::AVATAR_METADATA {
- for item in items.into_iter() {
- let payload = item.payload.clone().unwrap();
- if payload.is("metadata", ns::AVATAR_METADATA) {
- // TODO: do something with these metadata.
- let _metadata =
- AvatarMetadata::try_from(payload).unwrap();
- println!(
- "[1m{}[0m has published an avatar, downloading...",
- from.clone()
- );
- let iq = download_avatar(from.clone());
- client.send_stanza(iq.into()).await.unwrap();
+ for child in message.payloads {
+ if child.is("event", ns::PUBSUB_EVENT) {
+ let event = PubSubEvent::try_from(child).unwrap();
+ if let PubSubEvent::PublishedItems { node, items } = event {
+ if node.0 == ns::AVATAR_METADATA {
+ for item in items.into_iter() {
+ let payload = item.payload.clone().unwrap();
+ if payload.is("metadata", ns::AVATAR_METADATA) {
+ // TODO: do something with these metadata.
+ let _metadata =
+ AvatarMetadata::try_from(payload).unwrap();
+ println!(
+ "[1m{}[0m has published an avatar, downloading...",
+ from.clone()
+ );
+ let iq = download_avatar(from.clone());
+ client.send_stanza(iq.into()).await.unwrap();
+ }
}
}
}
}
}
}
- } else if stanza.is("presence", "jabber:client") {
// Nothing to do here.
- ()
- } else {
- panic!("Unknown stanza: {}", String::from(&stanza));
+ Stanza::Presence(_) => (),
}
}
} else {
@@ -164,10 +169,9 @@ fn make_error(
type_: ErrorType,
condition: DefinedCondition,
text: &str,
-) -> Element {
+) -> Iq {
let error = StanzaError::new(type_, condition, "en", text);
- let iq = Iq::from_error(id, error).with_to(to);
- iq.into()
+ Iq::from_error(id, error).with_to(to)
}
fn make_disco() -> DiscoInfoResult {
@@ -1,5 +1,4 @@
use futures::stream::StreamExt;
-use minidom::Element;
use std::env::args;
use std::process::exit;
use std::str::FromStr;
@@ -40,7 +39,7 @@ async fn main() {
println!("Online at {}", jid);
let presence = make_presence();
- client.send_stanza(presence).await.unwrap();
+ client.send_stanza(presence.into()).await.unwrap();
} else if let Some(message) = event
.into_stanza()
.and_then(|stanza| Message::try_from(stanza).ok())
@@ -55,7 +54,7 @@ async fn main() {
if message.type_ != MessageType::Error {
// This is a message we'll echo
let reply = make_reply(from.clone(), &body.0);
- client.send_stanza(reply).await.unwrap();
+ client.send_stanza(reply.into()).await.unwrap();
}
}
_ => {}
@@ -69,18 +68,18 @@ async fn main() {
}
// Construct a <presence/>
-fn make_presence() -> Element {
+fn make_presence() -> Presence {
let mut presence = Presence::new(PresenceType::None);
presence.show = Some(PresenceShow::Chat);
presence
.statuses
.insert(String::from("en"), String::from("Echoing messages."));
- presence.into()
+ presence
}
// Construct a chat <message/>
-fn make_reply(to: Jid, body: &str) -> Element {
+fn make_reply(to: Jid, body: &str) -> Message {
let mut message = Message::new(Some(to));
message.bodies.insert(String::new(), Body(body.to_owned()));
- message.into()
+ message
}
@@ -1,5 +1,4 @@
use futures::stream::StreamExt;
-use minidom::Element;
use std::env::args;
use std::process::exit;
use std::str::FromStr;
@@ -45,7 +44,7 @@ async fn main() {
Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(),
Jid::from_str("linkmauve@linkmauve.fr").unwrap(),
);
- component.send_stanza(presence).await.unwrap();
+ component.send_stanza(presence.into()).await.unwrap();
// Main loop, processes events
loop {
@@ -56,7 +55,7 @@ async fn main() {
(Some(from), Some(body)) => {
if message.type_ != MessageType::Error {
let reply = make_reply(from, &body.0);
- component.send_stanza(reply).await.unwrap();
+ component.send_stanza(reply.into()).await.unwrap();
}
}
_ => (),
@@ -69,7 +68,7 @@ async fn main() {
}
// Construct a <presence/>
-fn make_presence(from: Jid, to: Jid) -> Element {
+fn make_presence(from: Jid, to: Jid) -> Presence {
let mut presence = Presence::new(PresenceType::None);
presence.from = Some(from);
presence.to = Some(to);
@@ -77,12 +76,12 @@ fn make_presence(from: Jid, to: Jid) -> Element {
presence
.statuses
.insert(String::from("en"), String::from("Echoing messages."));
- presence.into()
+ presence
}
// Construct a chat <message/>
-fn make_reply(to: Jid, body: &str) -> Element {
+fn make_reply(to: Jid, body: &str) -> Message {
let mut message = Message::new(Some(to));
message.bodies.insert(String::new(), Body(body.to_owned()));
- message.into()
+ message
}
@@ -1,8 +1,8 @@
use futures::{SinkExt, StreamExt};
use tokio::{self, io, net::TcpSocket};
-use tokio_util::codec::Framed;
-use tokio_xmpp::proto::XmppCodec;
+use tokio_xmpp::parsers::stream_features::StreamFeatures;
+use tokio_xmpp::xmlstream::{accept_stream, StreamHeader};
#[tokio::main]
async fn main() -> Result<(), io::Error> {
@@ -16,16 +16,22 @@ async fn main() -> Result<(), io::Error> {
// Main loop, accepts incoming connections
loop {
let (stream, _addr) = listener.accept().await?;
-
- // Use the `XMPPCodec` to encode and decode frames
- let mut framed = Framed::new(stream, XmppCodec::new());
+ let stream = accept_stream(
+ tokio::io::BufStream::new(stream),
+ tokio_xmpp::parsers::ns::DEFAULT_NS,
+ )
+ .await?;
+ let stream = stream.send_header(StreamHeader::default()).await?;
+ let mut stream = stream
+ .send_features::<minidom::Element>(&StreamFeatures::default())
+ .await?;
tokio::spawn(async move {
- while let Some(packet) = framed.next().await {
+ while let Some(packet) = stream.next().await {
match packet {
Ok(packet) => {
println!("Received packet: {:?}", packet);
- framed.send(packet).await.unwrap();
+ stream.send(&packet).await.unwrap();
}
Err(e) => {
eprintln!("Error: {:?}", e);
@@ -1,46 +1,53 @@
-use futures::stream::StreamExt;
-use tokio::io::{AsyncRead, AsyncWrite};
+use std::io;
+
+use futures::{SinkExt, StreamExt};
+use tokio::io::{AsyncBufRead, AsyncWrite};
use xmpp_parsers::bind::{BindQuery, BindResponse};
use xmpp_parsers::iq::{Iq, IqType};
+use xmpp_parsers::stream_features::StreamFeatures;
use crate::error::{Error, ProtocolError};
-use crate::proto::{Packet, XmppStream};
+use crate::jid::{FullJid, Jid};
+use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement};
const BIND_REQ_ID: &str = "resource-bind";
-pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
- mut stream: XmppStream<S>,
-) -> Result<XmppStream<S>, Error> {
- if stream.stream_features.can_bind() {
- let resource = stream
- .jid
+pub async fn bind<S: AsyncBufRead + AsyncWrite + Unpin>(
+ stream: &mut XmppStream<S>,
+ features: &StreamFeatures,
+ jid: &Jid,
+) -> Result<Option<FullJid>, Error> {
+ if features.can_bind() {
+ let resource = jid
.resource()
.and_then(|resource| Some(resource.to_string()));
let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
- stream.send_stanza(iq).await?;
+ stream.send(&XmppStreamElement::Iq(iq)).await?;
loop {
match stream.next().await {
- Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
- Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
- IqType::Result(payload) => {
- payload
- .and_then(|payload| BindResponse::try_from(payload).ok())
- .map(|bind| stream.jid = bind.into());
- return Ok(stream);
+ Some(Ok(XmppStreamElement::Iq(iq))) if iq.id == BIND_REQ_ID => match iq.payload {
+ IqType::Result(Some(payload)) => match BindResponse::try_from(payload) {
+ Ok(v) => {
+ return Ok(Some(v.into()));
}
- _ => return Err(ProtocolError::InvalidBindResponse.into()),
+ Err(_) => return Err(ProtocolError::InvalidBindResponse.into()),
},
- _ => {}
+ _ => return Err(ProtocolError::InvalidBindResponse.into()),
},
Some(Ok(_)) => {}
- Some(Err(e)) => return Err(e),
- None => return Err(Error::Disconnected),
+ Some(Err(ReadError::SoftTimeout)) => {}
+ Some(Err(ReadError::HardError(e))) => return Err(e.into()),
+ Some(Err(ReadError::ParseError(e))) => {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
+ }
+ Some(Err(ReadError::StreamFooterReceived)) | None => {
+ return Err(Error::Disconnected)
+ }
}
}
} else {
- // No resource binding available,
- // return the (probably // usable) stream immediately
- return Ok(stream);
+ // No resource binding available, do nothing.
+ return Ok(None);
}
}
@@ -1,25 +1,32 @@
-use futures::stream::StreamExt;
+use futures::{SinkExt, StreamExt};
use sasl::client::mechanisms::{Anonymous, Plain, Scram};
use sasl::client::Mechanism;
use sasl::common::scram::{Sha1, Sha256};
use sasl::common::Credentials;
+use std::borrow::Cow;
use std::collections::HashSet;
+use std::io;
use std::str::FromStr;
-use tokio::io::{AsyncRead, AsyncWrite};
-use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
-use xmpp_parsers::{jid::Jid, ns};
+use tokio::io::{AsyncBufRead, AsyncWrite};
+use xmpp_parsers::{
+ jid::{FullJid, Jid},
+ ns,
+ sasl::{Auth, Mechanism as XMPPMechanism, Nonza, Response},
+ stream_features::{SaslMechanisms, StreamFeatures},
+};
use crate::{
client::bind::bind,
connect::ServerConnector,
error::{AuthError, Error, ProtocolError},
- proto::{Packet, XmppStream},
+ xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream},
};
-pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+pub async fn auth<S: AsyncBufRead + AsyncWrite + Unpin>(
mut stream: XmppStream<S>,
+ sasl_mechanisms: &SaslMechanisms,
creds: Credentials,
-) -> Result<S, Error> {
+) -> Result<InitiatingStream<S>, Error> {
let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
@@ -27,13 +34,7 @@ pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
Box::new(|| Box::new(Anonymous::new())),
];
- let remote_mechs: HashSet<String> = stream
- .stream_features
- .sasl_mechanisms
- .mechanisms
- .iter()
- .cloned()
- .collect();
+ let remote_mechs: HashSet<String> = sasl_mechanisms.mechanisms.iter().cloned().collect();
for local_mech in local_mechs {
let mut mechanism = local_mech();
@@ -43,43 +44,55 @@ pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
stream
- .send_stanza(Auth {
+ .send(&XmppStreamElement::Sasl(Nonza::Auth(Auth {
mechanism: mechanism_name,
data: initial,
- })
+ })))
.await?;
loop {
match stream.next().await {
- Some(Ok(Packet::Stanza(stanza))) => {
- if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
+ Some(Ok(XmppStreamElement::Sasl(sasl))) => match sasl {
+ Nonza::Challenge(challenge) => {
let response = mechanism
.response(&challenge.data)
.map_err(|e| AuthError::Sasl(e))?;
// Send response and loop
- stream.send_stanza(Response { data: response }).await?;
- } else if let Ok(_) = Success::try_from(stanza.clone()) {
- return Ok(stream.into_inner());
- } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
+ stream
+ .send(&XmppStreamElement::Sasl(Nonza::Response(Response {
+ data: response,
+ })))
+ .await?;
+ }
+ Nonza::Success(_) => return Ok(stream.initiate_reset()),
+ Nonza::Failure(failure) => {
return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
- // TODO: This code was needed for compatibility with some broken server,
- // but itβs been forgotten which. It is currently commented out so that we
- // can find it and fix the server software instead.
- /*
- } else if stanza.name() == "failure" {
- // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
- return Err(Error::Auth(AuthError::Sasl("failure".to_string())));
- */
- } else {
- // ignore and loop
}
+ _ => {
+ // Ignore?!
+ }
+ },
+ Some(Ok(el)) => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!(
+ "unexpected stream element during SASL negotiation: {:?}",
+ el
+ ),
+ )
+ .into())
+ }
+ Some(Err(ReadError::HardError(e))) => return Err(e.into()),
+ Some(Err(ReadError::ParseError(e))) => {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
+ }
+ Some(Err(ReadError::SoftTimeout)) => {
+ // We cannot do anything about soft timeouts here...
}
- Some(Ok(_)) => {
- // ignore and loop
+ Some(Err(ReadError::StreamFooterReceived)) | None => {
+ return Err(Error::Disconnected)
}
- Some(Err(e)) => return Err(e),
- None => return Err(Error::Disconnected),
}
}
}
@@ -94,24 +107,31 @@ pub async fn client_login<C: ServerConnector>(
server: C,
jid: Jid,
password: String,
-) -> Result<XmppStream<C::Stream>, Error> {
+) -> Result<(Option<FullJid>, StreamFeatures, XmppStream<C::Stream>), Error> {
let username = jid.node().unwrap().as_str();
let password = password;
let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?;
+ let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
- let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?;
+ let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
let creds = Credentials::default()
.with_username(username)
.with_password(password)
.with_channel_binding(channel_binding);
// Authenticated (unspecified) stream
- let stream = auth(xmpp_stream, creds).await?;
- // Authenticated XmppStream
- let xmpp_stream = XmppStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
+ let stream = auth(xmpp_stream, &features.sasl_mechanisms, creds).await?;
+ let stream = stream
+ .send_header(StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ })
+ .await?;
+ let (features, mut stream) = stream.recv_features().await?;
// XmppStream bound to user session
- let xmpp_stream = bind(xmpp_stream).await?;
- Ok(xmpp_stream)
+ let full_jid = bind(&mut stream, &features, &jid).await?;
+ Ok((full_jid, features, stream))
}
@@ -1,12 +1,11 @@
use futures::sink::SinkExt;
-use minidom::Element;
-use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures};
+use xmpp_parsers::{jid::Jid, stream_features::StreamFeatures};
use crate::{
client::{login::client_login, stream::ClientState},
connect::ServerConnector,
error::Error,
- proto::{add_stanza_id, Packet},
+ Stanza,
};
#[cfg(any(feature = "starttls", feature = "insecure-tcp"))]
@@ -47,21 +46,21 @@ impl<C: ServerConnector> Client<C> {
/// server).
pub fn bound_jid(&self) -> Option<&Jid> {
match self.state {
- ClientState::Connected(ref stream) => Some(&stream.jid),
+ ClientState::Connected { ref bound_jid, .. } => Some(bound_jid),
_ => None,
}
}
/// Send stanza
- pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
- self.send(Packet::Stanza(add_stanza_id(stanza, ns::JABBER_CLIENT)))
- .await
+ pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<(), Error> {
+ stanza.ensure_id();
+ self.send(stanza).await
}
/// Get the stream features (`<stream:features/>`) of the underlying stream
pub fn get_stream_features(&self) -> Option<&StreamFeatures> {
match self.state {
- ClientState::Connected(ref stream) => Some(&stream.stream_features),
+ ClientState::Connected { ref features, .. } => Some(features),
_ => None,
}
}
@@ -73,7 +72,14 @@ impl<C: ServerConnector> Client<C> {
///
/// Make sure to disable reconnect.
pub async fn send_end(&mut self) -> Result<(), Error> {
- self.send(Packet::StreamEnd).await
+ match self.state {
+ ClientState::Connected { ref mut stream, .. } => Ok(stream.close().await?),
+ ClientState::Connecting { .. } => {
+ self.state = ClientState::Disconnected;
+ Ok(())
+ }
+ _ => Ok(()),
+ }
}
}
@@ -1,22 +1,31 @@
use futures::{task::Poll, Future, Sink, Stream};
+use std::io;
use std::mem::replace;
use std::pin::Pin;
use std::task::Context;
use tokio::task::JoinHandle;
+use xmpp_parsers::{
+ jid::{FullJid, Jid},
+ stream_features::StreamFeatures,
+};
use crate::{
- client::login::client_login,
+ client::{login::client_login, Client},
connect::{AsyncReadAndWrite, ServerConnector},
- error::{Error, ProtocolError},
- proto::{Packet, XmppStream},
- Client, Event,
+ error::Error,
+ xmlstream::{xmpp::XmppStreamElement, ReadError, XmppStream},
+ Event, Stanza,
};
pub(crate) enum ClientState<S: AsyncReadAndWrite> {
Invalid,
Disconnected,
- Connecting(JoinHandle<Result<XmppStream<S>, Error>>),
- Connected(XmppStream<S>),
+ Connecting(JoinHandle<Result<(Option<FullJid>, StreamFeatures, XmppStream<S>), Error>>),
+ Connected {
+ stream: XmppStream<S>,
+ features: StreamFeatures,
+ bound_jid: Jid,
+ },
}
/// Incoming XMPP events
@@ -56,9 +65,13 @@ impl<C: ServerConnector> Stream for Client<C> {
Poll::Ready(None)
}
ClientState::Connecting(mut connect) => match Pin::new(&mut connect).poll(cx) {
- Poll::Ready(Ok(Ok(stream))) => {
- let bound_jid = stream.jid.clone();
- self.state = ClientState::Connected(stream);
+ Poll::Ready(Ok(Ok((bound_jid, features, stream)))) => {
+ let bound_jid = bound_jid.map(Jid::from).unwrap_or_else(|| self.jid.clone());
+ self.state = ClientState::Connected {
+ stream,
+ bound_jid: bound_jid.clone(),
+ features,
+ };
Poll::Ready(Some(Event::Online {
bound_jid,
resumed: false,
@@ -77,7 +90,11 @@ impl<C: ServerConnector> Stream for Client<C> {
Poll::Pending
}
},
- ClientState::Connected(mut stream) => {
+ ClientState::Connected {
+ mut stream,
+ features,
+ bound_jid,
+ } => {
// Poll sink
match Pin::new(&mut stream).poll_ready(cx) {
Poll::Pending => (),
@@ -99,40 +116,69 @@ impl<C: ServerConnector> Stream for Client<C> {
// return.
loop {
match Pin::new(&mut stream).poll_next(cx) {
- Poll::Ready(None) => {
+ Poll::Ready(None)
+ | Poll::Ready(Some(Err(ReadError::StreamFooterReceived))) => {
// EOF
self.state = ClientState::Disconnected;
return Poll::Ready(Some(Event::Disconnected(Error::Disconnected)));
}
- Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
- // Receive stanza
- self.state = ClientState::Connected(stream);
- return Poll::Ready(Some(Event::Stanza(stanza)));
- }
- Poll::Ready(Some(Ok(Packet::Text(_)))) => {
- // Ignore text between stanzas
+ Poll::Ready(Some(Err(ReadError::HardError(e)))) => {
+ // Treat stream as dead on I/O errors
+ self.state = ClientState::Disconnected;
+ return Poll::Ready(Some(Event::Disconnected(e.into())));
}
- Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
- // <stream:stream>
+ Poll::Ready(Some(Err(ReadError::ParseError(e)))) => {
+ // Treat stream as dead on parse errors, too (for now...)
self.state = ClientState::Disconnected;
return Poll::Ready(Some(Event::Disconnected(
- ProtocolError::InvalidStreamStart.into(),
+ io::Error::new(io::ErrorKind::InvalidData, e).into(),
)));
}
- Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
- // End of stream: </stream:stream>
- self.state = ClientState::Disconnected;
- return Poll::Ready(Some(Event::Disconnected(Error::Disconnected)));
+ Poll::Ready(Some(Err(ReadError::SoftTimeout))) => {
+ // TODO: do something smart about this.
+ }
+ Poll::Ready(Some(Ok(XmppStreamElement::Iq(stanza)))) => {
+ // Receive stanza
+ self.state = ClientState::Connected {
+ stream,
+ features,
+ bound_jid,
+ };
+ // TODO: use specific stanza types instead of going back to elements...
+ return Poll::Ready(Some(Event::Stanza(stanza.into())));
+ }
+ Poll::Ready(Some(Ok(XmppStreamElement::Message(stanza)))) => {
+ // Receive stanza
+ self.state = ClientState::Connected {
+ stream,
+ features,
+ bound_jid,
+ };
+ // TODO: use specific stanza types instead of going back to elements...
+ return Poll::Ready(Some(Event::Stanza(stanza.into())));
+ }
+ Poll::Ready(Some(Ok(XmppStreamElement::Presence(stanza)))) => {
+ // Receive stanza
+ self.state = ClientState::Connected {
+ stream,
+ features,
+ bound_jid,
+ };
+ // TODO: use specific stanza types instead of going back to elements...
+ return Poll::Ready(Some(Event::Stanza(stanza.into())));
+ }
+ Poll::Ready(Some(Ok(_))) => {
+ // We ignore these for now.
}
Poll::Pending => {
// Try again later
- self.state = ClientState::Connected(stream);
+ self.state = ClientState::Connected {
+ stream,
+ features,
+ bound_jid,
+ };
return Poll::Pending;
}
- Poll::Ready(Some(Err(e))) => {
- self.state = ClientState::Disconnected;
- return Poll::Ready(Some(Event::Disconnected(e.into())));
- }
}
}
}
@@ -143,21 +189,21 @@ impl<C: ServerConnector> Stream for Client<C> {
/// Outgoing XMPP packets
///
/// See `send_stanza()` for an `async fn`
-impl<C: ServerConnector> Sink<Packet> for Client<C> {
+impl<C: ServerConnector> Sink<Stanza> for Client<C> {
type Error = Error;
- fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
+ fn start_send(mut self: Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
match self.state {
- ClientState::Connected(ref mut stream) => {
- Pin::new(stream).start_send(item).map_err(|e| e.into())
- }
+ ClientState::Connected { ref mut stream, .. } => Pin::new(stream)
+ .start_send(&item.into())
+ .map_err(|e| e.into()),
_ => Err(Error::InvalidState),
}
}
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.state {
- ClientState::Connected(ref mut stream) => {
+ ClientState::Connected { ref mut stream, .. } => {
Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
}
_ => Poll::Pending,
@@ -166,7 +212,7 @@ impl<C: ServerConnector> Sink<Packet> for Client<C> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.state {
- ClientState::Connected(ref mut stream) => {
+ ClientState::Connected { ref mut stream, .. } => {
Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
}
_ => Poll::Pending,
@@ -175,7 +221,7 @@ impl<C: ServerConnector> Sink<Packet> for Client<C> {
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.state {
- ClientState::Connected(ref mut stream) => {
+ ClientState::Connected { ref mut stream, .. } => {
Pin::new(stream).poll_close(cx).map_err(|e| e.into())
}
_ => Poll::Pending,
@@ -1,12 +1,12 @@
-use futures::stream::StreamExt;
-use tokio::io::{AsyncRead, AsyncWrite};
+use std::io;
+
+use futures::{SinkExt, StreamExt};
+use tokio::io::{AsyncBufRead, AsyncWrite};
use xmpp_parsers::{component::Handshake, jid::Jid, ns};
-use crate::{
- connect::ServerConnector,
- error::{AuthError, Error},
- proto::{Packet, XmppStream},
-};
+use crate::component::ServerConnector;
+use crate::error::{AuthError, Error};
+use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement};
/// Log into an XMPP server as a client with a jid+pass
pub async fn component_login<C: ServerConnector>(
@@ -15,32 +15,47 @@ pub async fn component_login<C: ServerConnector>(
password: String,
) -> Result<XmppStream<C::Stream>, Error> {
let password = password;
- let mut xmpp_stream = connector.connect(&jid, ns::COMPONENT).await?;
- auth(&mut xmpp_stream, password).await?;
- Ok(xmpp_stream)
+ let mut stream = connector.connect(&jid, ns::COMPONENT).await?;
+ let header = stream.take_header();
+ let mut stream = stream.skip_features();
+ let stream_id = match header.id {
+ Some(ref v) => &**v,
+ None => {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ "stream id missing on component stream",
+ )
+ .into())
+ }
+ };
+ auth(&mut stream, stream_id, &password).await?;
+ Ok(stream)
}
-pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+pub async fn auth<S: AsyncBufRead + AsyncWrite + Unpin>(
stream: &mut XmppStream<S>,
- password: String,
+ stream_id: &str,
+ password: &str,
) -> Result<(), Error> {
- let nonza = Handshake::from_password_and_stream_id(&password, &stream.id);
- stream.send_stanza(nonza).await?;
+ let nonza = Handshake::from_password_and_stream_id(password, stream_id);
+ stream
+ .send(&XmppStreamElement::ComponentHandshake(nonza))
+ .await?;
loop {
match stream.next().await {
- Some(Ok(Packet::Stanza(ref stanza)))
- if stanza.is("handshake", ns::COMPONENT_ACCEPT) =>
- {
+ Some(Ok(XmppStreamElement::ComponentHandshake(_))) => {
return Ok(());
}
- Some(Ok(Packet::Stanza(ref stanza)))
- if stanza.is("error", "http://etherx.jabber.org/streams") =>
- {
+ Some(Ok(_)) => {
return Err(AuthError::ComponentFail.into());
}
- Some(_) => {}
- None => return Err(Error::Disconnected),
+ Some(Err(ReadError::SoftTimeout)) => (),
+ Some(Err(ReadError::HardError(e))) => return Err(e.into()),
+ Some(Err(ReadError::ParseError(e))) => {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
+ }
+ Some(Err(ReadError::StreamFooterReceived)) | None => return Err(Error::Disconnected),
}
}
}
@@ -2,15 +2,12 @@
//! XMPP server under a JID consisting of just a domain name. They are
//! allowed to use any user and resource identifiers in their stanzas.
use futures::sink::SinkExt;
-use minidom::Element;
use std::str::FromStr;
-use xmpp_parsers::{jid::Jid, ns};
+use xmpp_parsers::jid::Jid;
use crate::{
- component::login::component_login,
- connect::ServerConnector,
- proto::{add_stanza_id, XmppStream},
- Error,
+ component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error,
+ Stanza,
};
#[cfg(any(feature = "starttls", feature = "insecure-tcp"))]
@@ -33,8 +30,9 @@ pub struct Component<C: ServerConnector> {
impl<C: ServerConnector> Component<C> {
/// Send stanza
- pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
- self.send(add_stanza_id(stanza, ns::COMPONENT_ACCEPT)).await
+ pub async fn send_stanza(&mut self, mut stanza: Stanza) -> Result<(), Error> {
+ stanza.ensure_id();
+ self.send(stanza).await
}
/// End connection
@@ -2,21 +2,27 @@
//! XMPP server under a JID consisting of just a domain name. They are
//! allowed to use any user and resource identifiers in their stanzas.
use futures::{task::Poll, Sink, Stream};
-use minidom::Element;
use std::pin::Pin;
use std::task::Context;
-use crate::{component::Component, connect::ServerConnector, proto::Packet, Error};
+use crate::{
+ component::Component, connect::ServerConnector, xmlstream::XmppStreamElement, Error, Stanza,
+};
impl<C: ServerConnector> Stream for Component<C> {
- type Item = Element;
+ type Item = Stanza;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
loop {
match Pin::new(&mut self.stream).poll_next(cx) {
- Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => return Poll::Ready(Some(stanza)),
- Poll::Ready(Some(Ok(Packet::Text(_)))) => {
- // retry
+ Poll::Ready(Some(Ok(XmppStreamElement::Iq(stanza)))) => {
+ return Poll::Ready(Some(Stanza::Iq(stanza)))
+ }
+ Poll::Ready(Some(Ok(XmppStreamElement::Message(stanza)))) => {
+ return Poll::Ready(Some(Stanza::Message(stanza)))
+ }
+ Poll::Ready(Some(Ok(XmppStreamElement::Presence(stanza)))) => {
+ return Poll::Ready(Some(Stanza::Presence(stanza)))
}
Poll::Ready(Some(Ok(_))) =>
// unexpected
@@ -31,12 +37,12 @@ impl<C: ServerConnector> Stream for Component<C> {
}
}
-impl<C: ServerConnector> Sink<Element> for Component<C> {
+impl<C: ServerConnector> Sink<Stanza> for Component<C> {
type Error = Error;
- fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> {
+ fn start_send(mut self: Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
Pin::new(&mut self.stream)
- .start_send(Packet::Stanza(item))
+ .start_send(&item.into())
.map_err(|e| e.into())
}
@@ -1,10 +1,10 @@
//! `ServerConnector` provides streams for XMPP clients
use sasl::common::ChannelBinding;
-use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncBufRead, AsyncWrite};
use xmpp_parsers::jid::Jid;
-use crate::proto::XmppStream;
+use crate::xmlstream::PendingFeaturesRecv;
use crate::Error;
#[cfg(feature = "starttls")]
@@ -21,8 +21,8 @@ mod dns;
pub use dns::DnsConfig;
/// trait returned wrapped in XmppStream by ServerConnector
-pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
-impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
+pub trait AsyncReadAndWrite: AsyncBufRead + AsyncWrite + Unpin + Send {}
+impl<T: AsyncBufRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
/// Trait that must be extended by the implementation of ServerConnector
pub trait ServerConnectorError: std::error::Error + Sync + Send {}
@@ -35,8 +35,8 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
fn connect(
&self,
jid: &Jid,
- ns: &str,
- ) -> impl std::future::Future<Output = Result<XmppStream<Self::Stream>, Error>> + Send;
+ ns: &'static str,
+ ) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
/// Return channel binding data if available
/// do not fail if channel binding is simply unavailable, just return Ok(None)
@@ -2,8 +2,10 @@
#[cfg(feature = "tls-native")]
use native_tls::Error as TlsError;
+use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
+use std::io;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@@ -28,18 +30,23 @@ use {
tokio_native_tls::{TlsConnector, TlsStream},
};
-use minidom::Element;
use sasl::common::ChannelBinding;
use tokio::{
- io::{AsyncRead, AsyncWrite},
+ io::{AsyncRead, AsyncWrite, BufStream},
net::TcpStream,
};
-use xmpp_parsers::{jid::Jid, ns};
+use xmpp_parsers::{
+ jid::Jid,
+ starttls::{self, Request},
+};
use crate::{
connect::{DnsConfig, ServerConnector, ServerConnectorError},
error::{Error, ProtocolError},
- proto::{Packet, XmppStream},
+ xmlstream::{
+ initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream,
+ XmppStreamElement,
+ },
Client,
};
@@ -57,20 +64,44 @@ impl From<DnsConfig> for StartTlsServerConnector {
}
impl ServerConnector for StartTlsServerConnector {
- type Stream = TlsStream<TcpStream>;
- async fn connect(&self, jid: &Jid, ns: &str) -> Result<XmppStream<Self::Stream>, Error> {
- let tcp_stream = self.0.resolve().await?;
+ type Stream = BufStream<TlsStream<TcpStream>>;
- // Unencryped XmppStream
- let xmpp_stream = XmppStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?;
+ async fn connect(
+ &self,
+ jid: &Jid,
+ ns: &'static str,
+ ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+ let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
- if xmpp_stream.stream_features.can_starttls() {
+ // Unencryped XmppStream
+ let xmpp_stream = initiate_stream(
+ tcp_stream,
+ ns,
+ StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ },
+ )
+ .await?;
+ let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
+
+ if features.can_starttls() {
// TlsStream
- let tls_stream = starttls(xmpp_stream).await?;
+ let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
// Encrypted XmppStream
- Ok(XmppStream::start(tls_stream, jid.clone(), ns.to_owned()).await?)
+ Ok(initiate_stream(
+ tokio::io::BufStream::new(tls_stream),
+ ns,
+ StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ },
+ )
+ .await?)
} else {
- return Err(crate::Error::Protocol(ProtocolError::NoTls).into());
+ Err(crate::Error::Protocol(ProtocolError::NoTls).into())
}
}
@@ -84,7 +115,7 @@ impl ServerConnector for StartTlsServerConnector {
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
{
- let (_, connection) = stream.get_ref();
+ let (_, connection) = stream.get_ref().get_ref();
Ok(match connection.protocol_version() {
// TODO: Add support for TLS 1.2 and earlier.
Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
@@ -102,10 +133,11 @@ impl ServerConnector for StartTlsServerConnector {
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
- xmpp_stream: XmppStream<S>,
+ xmpp_stream: XmppStream<BufStream<S>>,
+ domain: &str,
) -> Result<TlsStream<S>, Error> {
- let domain = xmpp_stream.jid.domain().to_owned();
- let stream = xmpp_stream.into_inner();
+ let domain = domain.to_owned();
+ let stream = xmpp_stream.into_inner().into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream)
.await
@@ -115,11 +147,11 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
- xmpp_stream: XmppStream<S>,
+ xmpp_stream: XmppStream<BufStream<S>>,
+ domain: &str,
) -> Result<TlsStream<S>, Error> {
- let domain = xmpp_stream.jid.domain().to_string();
- let domain = ServerName::try_from(domain).map_err(|e| StartTlsError::DnsNameError(e))?;
- let stream = xmpp_stream.into_inner();
+ let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
+ let stream = xmpp_stream.into_inner().into_inner();
let mut root_store = RootCertStore::empty();
#[cfg(feature = "webpki-roots")]
{
@@ -142,24 +174,33 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
/// Performs `<starttls/>` on an XmppStream and returns a binary
/// TlsStream.
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
- mut xmpp_stream: XmppStream<S>,
+ mut stream: XmppStream<BufStream<S>>,
+ domain: &str,
) -> Result<TlsStream<S>, Error> {
- let nonza = Element::builder("starttls", ns::TLS).build();
- let packet = Packet::Stanza(nonza);
- xmpp_stream.send(packet).await?;
+ stream
+ .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
+ Request,
+ )))
+ .await?;
loop {
- match xmpp_stream.next().await {
- Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
- Some(Ok(Packet::Text(_))) => {}
- Some(Err(e)) => return Err(e.into()),
- _ => {
- return Err(crate::Error::Protocol(ProtocolError::NoTls).into());
+ match stream.next().await {
+ Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
+ break;
+ }
+ Some(Ok(_)) => (),
+ Some(Err(ReadError::SoftTimeout)) => (),
+ Some(Err(ReadError::HardError(e))) => return Err(e.into()),
+ Some(Err(ReadError::ParseError(e))) => {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
+ }
+ None | Some(Err(ReadError::StreamFooterReceived)) => {
+ return Err(crate::Error::Disconnected)
}
}
}
- get_tls_stream(xmpp_stream).await
+ get_tls_stream(stream, domain).await
}
/// StartTLS ServerConnector Error
@@ -1,10 +1,12 @@
//! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections
-use tokio::net::TcpStream;
+use std::borrow::Cow;
+
+use tokio::{io::BufStream, net::TcpStream};
use crate::{
connect::{DnsConfig, ServerConnector},
- proto::XmppStream,
+ xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader},
Client, Component, Error,
};
@@ -27,14 +29,23 @@ impl From<DnsConfig> for TcpServerConnector {
}
impl ServerConnector for TcpServerConnector {
- type Stream = TcpStream;
+ type Stream = BufStream<TcpStream>;
async fn connect(
&self,
jid: &xmpp_parsers::jid::Jid,
- ns: &str,
- ) -> Result<XmppStream<Self::Stream>, Error> {
- let stream = self.0.resolve().await?;
- Ok(XmppStream::start(stream, jid.clone(), ns.to_owned()).await?)
+ ns: &'static str,
+ ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+ let stream = BufStream::new(self.0.resolve().await?);
+ Ok(initiate_stream(
+ stream,
+ ns,
+ StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ },
+ )
+ .await?)
}
}
@@ -1,6 +1,103 @@
-use super::Error;
-use minidom::Element;
-use xmpp_parsers::jid::Jid;
+use rand::{thread_rng, Rng};
+use xmpp_parsers::{iq::Iq, jid::Jid, message::Message, presence::Presence};
+
+use crate::xmlstream::XmppStreamElement;
+use crate::Error;
+
+fn make_id() -> String {
+ let id: u64 = thread_rng().gen();
+ format!("{}", id)
+}
+
+/// A stanza sent/received over the stream.
+#[derive(Debug)]
+pub enum Stanza {
+ /// IQ stanza
+ Iq(Iq),
+
+ /// Message stanza
+ Message(Message),
+
+ /// Presence stanza
+ Presence(Presence),
+}
+
+impl Stanza {
+ /// Assign a random ID to the stanza, if no ID has been assigned yet.
+ pub fn ensure_id(&mut self) -> &str {
+ match self {
+ Self::Iq(iq) => {
+ if iq.id.len() == 0 {
+ iq.id = make_id();
+ }
+ &iq.id
+ }
+ Self::Message(message) => message.id.get_or_insert_with(make_id),
+ Self::Presence(presence) => presence.id.get_or_insert_with(make_id),
+ }
+ }
+}
+
+impl From<Iq> for Stanza {
+ fn from(other: Iq) -> Self {
+ Self::Iq(other)
+ }
+}
+
+impl From<Presence> for Stanza {
+ fn from(other: Presence) -> Self {
+ Self::Presence(other)
+ }
+}
+
+impl From<Message> for Stanza {
+ fn from(other: Message) -> Self {
+ Self::Message(other)
+ }
+}
+
+impl TryFrom<Stanza> for Message {
+ type Error = Stanza;
+
+ fn try_from(other: Stanza) -> Result<Self, Self::Error> {
+ match other {
+ Stanza::Message(st) => Ok(st),
+ other => Err(other),
+ }
+ }
+}
+
+impl TryFrom<Stanza> for Presence {
+ type Error = Stanza;
+
+ fn try_from(other: Stanza) -> Result<Self, Self::Error> {
+ match other {
+ Stanza::Presence(st) => Ok(st),
+ other => Err(other),
+ }
+ }
+}
+
+impl TryFrom<Stanza> for Iq {
+ type Error = Stanza;
+
+ fn try_from(other: Stanza) -> Result<Self, Self::Error> {
+ match other {
+ Stanza::Iq(st) => Ok(st),
+ other => Err(other),
+ }
+ }
+}
+
+impl From<Stanza> for XmppStreamElement {
+ fn from(other: Stanza) -> Self {
+ match other {
+ Stanza::Iq(st) => Self::Iq(st),
+ Stanza::Message(st) => Self::Message(st),
+ Stanza::Presence(st) => Self::Presence(st),
+ }
+ }
+}
/// High-level event on the Stream implemented by Client and Component
#[derive(Debug)]
@@ -21,7 +118,7 @@ pub enum Event {
/// Stream end
Disconnected(Error),
/// Received stanza/nonza
- Stanza(Element),
+ Stanza(Stanza),
}
impl Event {
@@ -41,16 +138,8 @@ impl Event {
}
}
- /// `Stanza` event?
- pub fn is_stanza(&self, name: &str) -> bool {
- match *self {
- Event::Stanza(ref stanza) => stanza.name() == name,
- _ => false,
- }
- }
-
/// If this is a `Stanza` event, get its data
- pub fn as_stanza(&self) -> Option<&Element> {
+ pub fn as_stanza(&self) -> Option<&Stanza> {
match *self {
Event::Stanza(ref stanza) => Some(stanza),
_ => None,
@@ -58,7 +147,7 @@ impl Event {
}
/// If this is a `Stanza` event, unwrap into its data
- pub fn into_stanza(self) -> Option<Element> {
+ pub fn into_stanza(self) -> Option<Stanza> {
match self {
Event::Stanza(stanza) => Some(stanza),
_ => None,
@@ -47,9 +47,8 @@ compile_error!(
);
mod event;
-pub use event::Event;
+pub use event::{Event, Stanza};
pub mod connect;
-pub mod proto;
pub mod xmlstream;
mod client;
@@ -1,8 +0,0 @@
-//! Low-level stream establishment
-
-mod xmpp_codec;
-mod xmpp_stream;
-
-pub use xmpp_codec::{Packet, XmppCodec};
-pub(crate) use xmpp_stream::add_stanza_id;
-pub use xmpp_stream::XmppStream;
@@ -1,410 +0,0 @@
-//! XML stream parser for XMPP
-
-use crate::Error;
-use bytes::{BufMut, BytesMut};
-use log::debug;
-use minidom::tree_builder::TreeBuilder;
-use minidom::Element;
-use rxml::{Parse, RawParser};
-use std::collections::HashMap;
-use std::fmt::Write;
-use std::io;
-#[cfg(feature = "syntax-highlighting")]
-use std::sync::OnceLock;
-use tokio_util::codec::{Decoder, Encoder};
-
-#[cfg(feature = "syntax-highlighting")]
-static PS: OnceLock<syntect::parsing::SyntaxSet> = OnceLock::new();
-#[cfg(feature = "syntax-highlighting")]
-static SYNTAX: OnceLock<syntect::parsing::SyntaxReference> = OnceLock::new();
-#[cfg(feature = "syntax-highlighting")]
-static THEME: OnceLock<syntect::highlighting::Theme> = OnceLock::new();
-
-#[cfg(feature = "syntax-highlighting")]
-fn init_syntect() {
- let ps = syntect::parsing::SyntaxSet::load_defaults_newlines();
- let syntax = ps.find_syntax_by_extension("xml").unwrap();
- let ts = syntect::highlighting::ThemeSet::load_defaults();
- let theme = ts.themes["Solarized (dark)"].clone();
-
- SYNTAX.set(syntax.clone()).unwrap();
- PS.set(ps).unwrap();
- THEME.set(theme).unwrap();
-}
-
-#[cfg(feature = "syntax-highlighting")]
-fn highlight_xml(xml: &str) -> String {
- let mut h = syntect::easy::HighlightLines::new(SYNTAX.get().unwrap(), THEME.get().unwrap());
- let ranges: Vec<_> = h.highlight_line(&xml, PS.get().unwrap()).unwrap();
- let escaped = syntect::util::as_24_bit_terminal_escaped(&ranges[..], false);
- format!("{}\x1b[0m", escaped)
-}
-
-#[cfg(not(feature = "syntax-highlighting"))]
-fn highlight_xml(xml: &str) -> &str {
- xml
-}
-
-/// Anything that can be sent or received on an XMPP/XML stream
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum Packet {
- /// `<stream:stream>` start tag
- StreamStart(HashMap<String, String>),
- /// A complete stanza or nonza
- Stanza(Element),
- /// Plain text (think whitespace keep-alive)
- Text(String),
- /// `</stream:stream>` closing tag
- StreamEnd,
-}
-
-/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
-pub struct XmppCodec {
- /// Outgoing
- ns: Option<String>,
- /// Incoming
- driver: RawParser,
- stanza_builder: TreeBuilder,
-}
-
-impl XmppCodec {
- /// Constructor
- pub fn new() -> Self {
- let stanza_builder = TreeBuilder::new();
- let driver = RawParser::new();
- #[cfg(feature = "syntax-highlighting")]
- if log::log_enabled!(log::Level::Debug) && PS.get().is_none() {
- init_syntect();
- }
- XmppCodec {
- ns: None,
- driver,
- stanza_builder,
- }
- }
-}
-
-impl Default for XmppCodec {
- fn default() -> Self {
- Self::new()
- }
-}
-
-impl Decoder for XmppCodec {
- type Item = Packet;
- type Error = Error;
-
- fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
- loop {
- let token = match self.driver.parse_buf(buf, false) {
- Ok(Some(token)) => token,
- Ok(None) => break,
- Err(rxml::error::EndOrError::NeedMoreData) => break,
- Err(rxml::error::EndOrError::Error(e)) => {
- return Err(minidom::Error::from(e).into())
- }
- };
-
- let had_stream_root = self.stanza_builder.depth() > 0;
- self.stanza_builder.process_event(token)?;
- let has_stream_root = self.stanza_builder.depth() > 0;
-
- if !had_stream_root && has_stream_root {
- let root = self.stanza_builder.top().unwrap();
- let attrs =
- root.attrs()
- .map(|(name, value)| (name.to_owned(), value.to_owned()))
- .chain(root.prefixes.declared_prefixes().iter().map(
- |(prefix, namespace)| {
- (
- prefix
- .as_ref()
- .map(|prefix| format!("xmlns:{}", prefix))
- .unwrap_or_else(|| "xmlns".to_owned()),
- namespace.clone(),
- )
- },
- ))
- .collect();
- debug!("<< {}", highlight_xml(&String::from(root)));
- return Ok(Some(Packet::StreamStart(attrs)));
- } else if self.stanza_builder.depth() == 1 {
- self.driver.release_temporaries();
-
- if let Some(stanza) = self.stanza_builder.unshift_child() {
- debug!("<< {}", highlight_xml(&String::from(&stanza)));
- return Ok(Some(Packet::Stanza(stanza)));
- }
- } else if let Some(_) = self.stanza_builder.root.take() {
- self.driver.release_temporaries();
-
- debug!("<< {}", highlight_xml("</stream:stream>"));
- return Ok(Some(Packet::StreamEnd));
- }
- }
-
- Ok(None)
- }
-
- fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
- self.decode(buf)
- }
-}
-
-impl Encoder<Packet> for XmppCodec {
- type Error = Error;
-
- fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
- let remaining = dst.capacity() - dst.len();
- let max_stanza_size: usize = 2usize.pow(16);
- if remaining < max_stanza_size {
- dst.reserve(max_stanza_size - remaining);
- }
-
- fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
- io::Error::new(io::ErrorKind::InvalidInput, e)
- }
-
- match item {
- Packet::StreamStart(start_attrs) => {
- let mut buf = String::new();
- write!(buf, "<stream:stream").map_err(to_io_err)?;
- for (name, value) in start_attrs {
- write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
- if name == "xmlns" {
- self.ns = Some(value);
- }
- }
- write!(buf, ">").map_err(to_io_err)?;
-
- write!(dst, "{}", buf)?;
- let utf8 = std::str::from_utf8(dst)?;
- debug!(">> {}", highlight_xml(utf8))
- }
- Packet::Stanza(stanza) => {
- let _ = stanza
- .write_to(&mut WriteBytes::new(dst))
- .map_err(|e| to_io_err(format!("{}", e)))?;
- let utf8 = std::str::from_utf8(dst)?;
- debug!(">> {}", highlight_xml(utf8));
- }
- Packet::Text(text) => {
- let _ = write_text(&text, dst).map_err(to_io_err)?;
- let utf8 = std::str::from_utf8(dst)?;
- debug!(">> {}", highlight_xml(utf8));
- }
- Packet::StreamEnd => {
- let _ = write!(dst, "</stream:stream>\n").map_err(to_io_err);
- debug!(">> {}", highlight_xml("</stream:stream>"));
- }
- }
-
- Ok(())
- }
-}
-
-/// Write XML-escaped text string
-pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
- write!(writer, "{}", escape(text))
-}
-
-/// Copied from `RustyXML` for now
-pub fn escape(input: &str) -> String {
- let mut result = String::with_capacity(input.len());
-
- for c in input.chars() {
- match c {
- '&' => result.push_str("&"),
- '<' => result.push_str("<"),
- '>' => result.push_str(">"),
- '\'' => result.push_str("'"),
- '"' => result.push_str("""),
- o => result.push(o),
- }
- }
- result
-}
-
-/// BytesMut impl only std::fmt::Write but not std::io::Write. The
-/// latter trait is required for minidom's
-/// `Element::write_to_inner()`.
-struct WriteBytes<'a> {
- dst: &'a mut BytesMut,
-}
-
-impl<'a> WriteBytes<'a> {
- fn new(dst: &'a mut BytesMut) -> Self {
- WriteBytes { dst }
- }
-}
-
-impl<'a> std::io::Write for WriteBytes<'a> {
- fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
- self.dst.put_slice(buf);
- Ok(buf.len())
- }
-
- fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
- Ok(())
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_stream_start() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
- }
-
- #[test]
- fn test_stream_end() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
- b.put_slice(b"</stream:stream>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamEnd)) => true,
- _ => false,
- });
- }
-
- #[test]
- fn test_truncated_stanza() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
-
- b.put_slice("<test>Γ</test".as_bytes());
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(None) => true,
- _ => false,
- });
-
- b.put_slice(b">");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "Γ" => true,
- _ => false,
- });
- }
-
- #[test]
- fn test_truncated_utf8() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
-
- b.put(&b"<test>\xc3"[..]);
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(None) => true,
- _ => false,
- });
-
- b.put(&b"\x9f</test>"[..]);
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "Γ" => true,
- _ => false,
- });
- }
-
- /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
- #[test]
- fn test_atrribute_prefix() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
-
- b.put_slice(b"<status xml:lang='en'>Test status</status>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::Stanza(ref el)))
- if el.name() == "status"
- && el.text() == "Test status"
- && el.attr("xml:lang").map_or(false, |a| a == "en") =>
- true,
- _ => false,
- });
- }
-
- /// By default, encode() only gets a BytesMut that has 8Β KiB space reserved.
- #[test]
- fn test_large_stanza() {
- use futures::{executor::block_on, sink::SinkExt};
- use std::io::Cursor;
- use tokio_util::codec::FramedWrite;
- let mut framed = FramedWrite::new(Cursor::new(vec![]), XmppCodec::new());
- let mut text = "".to_owned();
- for _ in 0..2usize.pow(15) {
- text = text + "A";
- }
- let stanza = Element::builder("message", "jabber:client")
- .append(
- Element::builder("body", "jabber:client")
- .append(text.as_ref())
- .build(),
- )
- .build();
- block_on(framed.send(Packet::Stanza(stanza))).expect("send");
- assert_eq!(
- framed.get_ref().get_ref(),
- &format!(
- "<message xmlns='jabber:client'><body>{}</body></message>",
- text
- )
- .as_bytes()
- );
- }
-
- #[test]
- fn test_cut_out_stanza() {
- let mut c = XmppCodec::new();
- let mut b = BytesMut::with_capacity(1024);
- b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::StreamStart(_))) => true,
- _ => false,
- });
-
- b.put_slice(b"<message ");
- b.put_slice(b"type='chat'><body>Foo</body></message>");
- let r = c.decode(&mut b);
- assert!(match r {
- Ok(Some(Packet::Stanza(_))) => true,
- _ => false,
- });
- }
-}
@@ -1,193 +0,0 @@
-//! `XmppStream` provides encoding/decoding for XMPP
-
-use futures::{
- sink::{Send, SinkExt},
- stream::StreamExt,
- task::Poll,
- Sink, Stream,
-};
-use minidom::Element;
-use rand::{thread_rng, Rng};
-use std::pin::Pin;
-use std::task::Context;
-use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_util::codec::Framed;
-use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures, Error as ParsersError};
-
-use crate::error::{Error, ProtocolError};
-use crate::proto::{Packet, XmppCodec};
-
-fn make_id() -> String {
- let id: u64 = thread_rng().gen();
- format!("{}", id)
-}
-
-pub(crate) fn add_stanza_id(mut stanza: Element, default_ns: &str) -> Element {
- if stanza.is("iq", default_ns)
- || stanza.is("message", default_ns)
- || stanza.is("presence", default_ns)
- {
- if stanza.attr("id").is_none() {
- stanza.set_attr("id", make_id());
- }
- }
-
- stanza
-}
-
-/// Wraps a binary stream (tokio's `AsyncRead + AsyncWrite`) to decode
-/// and encode XMPP packets.
-///
-/// Implements `Sink + Stream`
-pub struct XmppStream<S: AsyncRead + AsyncWrite + Unpin> {
- /// The local Jabber-Id
- pub jid: Jid,
- /// Codec instance
- pub stream: Framed<S, XmppCodec>,
- /// `<stream:features/>` for XMPP version 1.0
- pub stream_features: StreamFeatures,
- /// Root namespace
- ///
- /// This is different for either c2s, s2s, or component
- /// connections.
- pub ns: String,
- /// Stream `id` attribute
- pub id: String,
-}
-
-impl<S: AsyncRead + AsyncWrite + Unpin> XmppStream<S> {
- /// Constructor
- pub fn new(
- jid: Jid,
- stream: Framed<S, XmppCodec>,
- ns: String,
- id: String,
- stream_features: StreamFeatures,
- ) -> Self {
- XmppStream {
- jid,
- stream,
- stream_features,
- ns,
- id,
- }
- }
-
- /// Send a `<stream:stream>` start tag
- pub async fn start(stream: S, jid: Jid, ns: String) -> Result<Self, Error> {
- let mut stream = Framed::new(stream, XmppCodec::new());
- let attrs = [
- ("to".to_owned(), jid.domain().to_string()),
- ("version".to_owned(), "1.0".to_owned()),
- ("xmlns".to_owned(), ns.clone()),
- ("xmlns:stream".to_owned(), ns::STREAM.to_owned()),
- ]
- .iter()
- .cloned()
- .collect();
- stream.send(Packet::StreamStart(attrs)).await?;
-
- let stream_attrs;
- loop {
- match stream.next().await {
- Some(Ok(Packet::StreamStart(attrs))) => {
- stream_attrs = attrs;
- break;
- }
- Some(Ok(_)) => {}
- Some(Err(e)) => return Err(e.into()),
- None => return Err(Error::Disconnected),
- }
- }
-
- let stream_ns = stream_attrs
- .get("xmlns")
- .ok_or(ProtocolError::NoStreamNamespace)?
- .clone();
- let stream_id = stream_attrs
- .get("id")
- .ok_or(ProtocolError::NoStreamId)?
- .clone();
- if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() {
- loop {
- match stream.next().await {
- Some(Ok(Packet::Stanza(stanza))) => {
- let stream_features = StreamFeatures::try_from(stanza)
- .map_err(|e| Error::Protocol(ParsersError::from(e).into()))?;
- return Ok(XmppStream::new(jid, stream, ns, stream_id, stream_features));
- }
- Some(Ok(_)) => {}
- Some(Err(e)) => return Err(e.into()),
- None => return Err(Error::Disconnected),
- }
- }
- } else {
- // FIXME: huge hack, shouldnβt be an element!
- return Ok(XmppStream::new(
- jid,
- stream,
- ns,
- stream_id.clone(),
- StreamFeatures::default(),
- ));
- }
- }
-
- /// Unwraps the inner stream
- pub fn into_inner(self) -> S {
- self.stream.into_inner()
- }
-
- /// Re-run `start()`
- pub async fn restart(self) -> Result<Self, Error> {
- let stream = self.stream.into_inner();
- Self::start(stream, self.jid, self.ns).await
- }
-}
-
-impl<S: AsyncRead + AsyncWrite + Unpin> XmppStream<S> {
- /// Convenience method
- pub fn send_stanza<E: Into<Element>>(&mut self, e: E) -> Send<Self, Packet> {
- self.send(Packet::Stanza(e.into()))
- }
-}
-
-/// Proxy to self.stream
-impl<S: AsyncRead + AsyncWrite + Unpin> Sink<Packet> for XmppStream<S> {
- type Error = crate::Error;
-
- fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
- // Pin::new(&mut self.stream).poll_ready(ctx)
- // .map_err(|e| e.into())
- Poll::Ready(Ok(()))
- }
-
- fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
- Pin::new(&mut self.stream)
- .start_send(item)
- .map_err(|e| e.into())
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.stream)
- .poll_flush(cx)
- .map_err(|e| e.into())
- }
-
- fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.stream)
- .poll_close(cx)
- .map_err(|e| e.into())
- }
-}
-
-/// Proxy to self.stream
-impl<S: AsyncRead + AsyncWrite + Unpin> Stream for XmppStream<S> {
- type Item = Result<Packet, crate::Error>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
- Pin::new(&mut self.stream)
- .poll_next(cx)
- .map(|result| result.map(|result| result.map_err(|e| e.into())))
- }
-}
@@ -127,6 +127,10 @@ impl<Io: AsyncBufRead + AsyncWrite> RawXmlStream<Io> {
*this.parser.parser_pinned() = rxml::Parser::default();
*this.writer = Self::new_writer(this.stream_ns);
}
+
+ pub(super) fn into_inner(self) -> Io {
+ self.parser.into_inner().0
+ }
}
impl<Io> RawXmlStream<Io> {
@@ -83,4 +83,19 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin> PendingFeaturesRecv<Io> {
let features = ReadXso::read_from(Pin::new(&mut stream)).await?;
Ok((features, XmlStream::wrap(stream)))
}
+
+ /// Skip receiving the responder's stream features.
+ ///
+ /// The stream can be used for exchanging stream-level elements (stanzas
+ /// or "nonzas"). The Rust type for these elements must be given as type
+ /// parameter `T`.
+ ///
+ /// **Note:** Using this on RFC 6120 compliant streams where stream
+ /// features **are** sent after the stream header will cause a parse error
+ /// down the road (because the feature stream element cannot be handled).
+ /// The only place where this is useful is in
+ /// [XEP-0114](https://xmpp.org/extensions/xep-0114.html) connections.
+ pub fn skip_features<T: FromXml + AsXml>(self) -> XmlStream<Io, T> {
+ XmlStream::wrap(self.stream)
+ }
}
@@ -54,7 +54,8 @@ mod responder;
mod tests;
pub(crate) mod xmpp;
-use self::common::{RawXmlStream, ReadXsoError, ReadXsoState, StreamHeader};
+pub use self::common::StreamHeader;
+use self::common::{RawXmlStream, ReadXsoError, ReadXsoState};
pub use self::initiator::{InitiatingStream, PendingFeaturesRecv};
pub use self::responder::{AcceptedStream, PendingFeaturesSend};
pub use self::xmpp::XmppStreamElement;
@@ -235,6 +236,12 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin, T: FromXml + AsXml> XmlStream<Io, T>
let header = StreamHeader::recv(Pin::new(&mut stream)).await?;
Ok(AcceptedStream { stream, header })
}
+
+ /// Discard all XML state and return the inner I/O object.
+ pub fn into_inner(self) -> Io {
+ self.assert_retypable();
+ self.inner.into_inner()
+ }
}
impl<Io: AsyncBufRead, T: FromXml + AsXml> Stream for XmlStream<Io, T> {
@@ -6,7 +6,7 @@
use xso::{AsXml, FromXml};
-use xmpp_parsers::{iq::Iq, message::Message, presence::Presence, sasl, starttls};
+use xmpp_parsers::{component, iq::Iq, message::Message, presence::Presence, sasl, starttls};
/// Any valid XMPP stream-level element.
#[derive(FromXml, AsXml, Debug)]
@@ -31,4 +31,8 @@ pub enum XmppStreamElement {
/// STARTTLS-related nonza
#[xml(transparent)]
Starttls(starttls::Nonza),
+
+ /// Component protocol nonzas
+ #[xml(transparent)]
+ ComponentHandshake(component::Handshake),
}
@@ -7,10 +7,8 @@
use futures::StreamExt;
use tokio_xmpp::connect::ServerConnector;
use tokio_xmpp::{
- parsers::{
- disco::DiscoInfoQuery, iq::Iq, message::Message, presence::Presence, roster::Roster,
- },
- Event as TokioXmppEvent,
+ parsers::{disco::DiscoInfoQuery, iq::Iq, roster::Roster},
+ Event as TokioXmppEvent, Stanza,
};
use crate::{iq, message, presence, Agent, Event};
@@ -46,24 +44,17 @@ pub async fn wait_for_events<C: ServerConnector>(agent: &mut Agent<C>) -> Vec<Ev
TokioXmppEvent::Disconnected(e) => {
events.push(Event::Disconnected(e));
}
- TokioXmppEvent::Stanza(elem) => {
- if elem.is("iq", "jabber:client") {
- let iq = Iq::try_from(elem).unwrap();
- let new_events = iq::handle_iq(agent, iq).await;
- events.extend(new_events);
- } else if elem.is("message", "jabber:client") {
- let message = Message::try_from(elem).unwrap();
- let new_events = message::receive::handle_message(agent, message).await;
- events.extend(new_events);
- } else if elem.is("presence", "jabber:client") {
- let presence = Presence::try_from(elem).unwrap();
- let new_events = presence::receive::handle_presence(agent, presence).await;
- events.extend(new_events);
- } else if elem.is("error", "http://etherx.jabber.org/streams") {
- println!("Received a fatal stream error: {}", String::from(&elem));
- } else {
- panic!("Unknown stanza: {}", String::from(&elem));
- }
+ TokioXmppEvent::Stanza(Stanza::Iq(iq)) => {
+ let new_events = iq::handle_iq(agent, iq).await;
+ events.extend(new_events);
+ }
+ TokioXmppEvent::Stanza(Stanza::Message(message)) => {
+ let new_events = message::receive::handle_message(agent, message).await;
+ events.extend(new_events);
+ }
+ TokioXmppEvent::Stanza(Stanza::Presence(presence)) => {
+ let new_events = presence::receive::handle_presence(agent, presence).await;
+ events.extend(new_events);
}
}