Detailed changes
@@ -1,6 +1,6 @@
[package]
name = "tokio-xmpp"
-version = "1.0.1"
+version = "2.0.0"
authors = ["Astro <astro@spaceboyz.net>", "Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>", "pep <pep+code@bouah.net>", "O01eg <o01eg@yandex.ru>"]
description = "Asynchronous XMPP for Rust with tokio"
license = "MPL-2.0"
@@ -12,17 +12,16 @@ keywords = ["xmpp", "tokio"]
edition = "2018"
[dependencies]
-bytes = "0.4"
-futures = "0.1"
+bytes = "0.5"
+futures = "0.3"
idna = "0.2"
log = "0.4"
native-tls = "0.2"
sasl = "0.4"
-tokio = "0.1"
-tokio-codec = "0.1"
-trust-dns-resolver = "0.12"
-trust-dns-proto = "0.8"
-tokio-io = "0.1"
-tokio-tls = "0.2"
+tokio = { version = "0.2", features = ["net", "stream", "rt-util", "rt-threaded", "macros"] }
+tokio-util = { version = "0.2", features = ["codec"] }
+tokio-tls = "0.3"
+trust-dns-resolver = "0.19"
+trust-dns-proto = "0.19"
xml5ever = "0.16"
xmpp-parsers = "0.17"
@@ -1,9 +1,8 @@
-use futures::{future, Sink, Stream};
+use futures::stream::StreamExt;
use std::convert::TryFrom;
use std::env::args;
use std::process::exit;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{xmpp_codec::Packet, Client};
+use tokio_xmpp::Client;
use xmpp_parsers::{
disco::{DiscoInfoQuery, DiscoInfoResult},
iq::{Iq, IqType},
@@ -12,70 +11,55 @@ use xmpp_parsers::{
Element, Jid,
};
-fn main() {
+#[tokio::main]
+async fn main() {
let args: Vec<String> = args().collect();
if args.len() != 4 {
println!("Usage: {} <jid> <password> <target>", args[0]);
exit(1);
}
let jid = &args[1];
- let password = &args[2];
+ let password = args[2].clone();
let target = &args[3];
- // tokio_core context
- let mut rt = Runtime::new().unwrap();
// Client instance
- let client = Client::new(jid, password).unwrap();
+ let mut client = Client::new(jid, password).unwrap();
- // Make the two interfaces for sending and receiving independent
- // of each other so we can move one into a closure.
- let (mut sink, stream) = client.split();
- // Wrap sink in Option so that we can take() it for the send(self)
- // to consume and return it back when ready.
- let mut send = move |packet| {
- sink.start_send(packet).expect("start_send");
- };
// Main loop, processes events
let mut wait_for_stream_end = false;
- let done = stream.for_each(|event| {
- if wait_for_stream_end {
- /* Do Nothing. */
- } else if event.is_online() {
- println!("Online!");
+ let mut stream_ended = false;
+ while !stream_ended {
+ if let Some(event) = client.next().await {
+ if wait_for_stream_end {
+ /* Do Nothing. */
+ } else if event.is_online() {
+ println!("Online!");
- let target_jid: Jid = target.clone().parse().unwrap();
- let iq = make_disco_iq(target_jid);
- println!("Sending disco#info request to {}", target.clone());
- println!(">> {}", String::from(&iq));
- send(Packet::Stanza(iq));
- } else if let Some(stanza) = event.into_stanza() {
- if stanza.is("iq", "jabber:client") {
- let iq = Iq::try_from(stanza).unwrap();
- if let IqType::Result(Some(payload)) = iq.payload {
- if payload.is("query", ns::DISCO_INFO) {
- if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
- for ext in disco_info.extensions {
- if let Ok(server_info) = ServerInfo::try_from(ext) {
- print_server_info(server_info);
- wait_for_stream_end = true;
- send(Packet::StreamEnd);
+ let target_jid: Jid = target.clone().parse().unwrap();
+ let iq = make_disco_iq(target_jid);
+ println!("Sending disco#info request to {}", target.clone());
+ println!(">> {}", String::from(&iq));
+ client.send_stanza(iq).await.unwrap();
+ } else if let Some(stanza) = event.into_stanza() {
+ if stanza.is("iq", "jabber:client") {
+ let iq = Iq::try_from(stanza).unwrap();
+ if let IqType::Result(Some(payload)) = iq.payload {
+ if payload.is("query", ns::DISCO_INFO) {
+ if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
+ for ext in disco_info.extensions {
+ if let Ok(server_info) = ServerInfo::try_from(ext) {
+ print_server_info(server_info);
+ }
}
}
}
+ wait_for_stream_end = true;
+ client.send_end().await.unwrap();
}
}
}
- }
-
- Box::new(future::ok(()))
- });
-
- // Start polling `done`
- match rt.block_on(done) {
- Ok(_) => (),
- Err(e) => {
- println!("Fatal: {}", e);
- ()
+ } else {
+ stream_ended = true;
}
}
}
@@ -1,12 +1,12 @@
-use futures::{future, Future, Sink, Stream};
+use futures::stream::StreamExt;
use std::convert::TryFrom;
use std::env::args;
use std::fs::{create_dir_all, File};
use std::io::{self, Write};
use std::process::exit;
use std::str::FromStr;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{Client, Packet};
+use tokio;
+use tokio_xmpp::Client;
use xmpp_parsers::{
avatar::{Data as AvatarData, Metadata as AvatarMetadata},
caps::{compute_disco, hash_caps, Caps},
@@ -22,162 +22,153 @@ use xmpp_parsers::{
NodeName,
},
stanza_error::{DefinedCondition, ErrorType, StanzaError},
- Jid,
+ Element, Jid,
};
-fn main() {
+#[tokio::main]
+async fn main() {
let args: Vec<String> = args().collect();
if args.len() != 3 {
println!("Usage: {} <jid> <password>", args[0]);
exit(1);
}
let jid = &args[1];
- let password = &args[2];
+ let password = args[2].clone();
- // tokio_core context
- let mut rt = Runtime::new().unwrap();
// Client instance
- let client = Client::new(jid, password).unwrap();
-
- // Make the two interfaces for sending and receiving independent
- // of each other so we can move one into a closure.
- let (sink, stream) = client.split();
-
- // Create outgoing pipe
- let (mut tx, rx) = futures::unsync::mpsc::unbounded();
- rt.spawn(
- rx.forward(sink.sink_map_err(|_| panic!("Pipe")))
- .map(|(rx, mut sink)| {
- drop(rx);
- let _ = sink.close();
- })
- .map_err(|e| {
- panic!("Send error: {:?}", e);
- }),
- );
+ let mut client = Client::new(jid, password).unwrap();
let disco_info = make_disco();
// Main loop, processes events
let mut wait_for_stream_end = false;
- let done = stream.for_each(move |event| {
- // Helper function to send an iq error.
- let mut send_error = |to, id, type_, condition, text: &str| {
- let error = StanzaError::new(type_, condition, "en", text);
- let iq = Iq::from_error(id, error).with_to(to);
- tx.start_send(Packet::Stanza(iq.into())).unwrap();
- };
-
- if wait_for_stream_end {
- /* Do nothing */
- } else if event.is_online() {
- println!("Online!");
-
- let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp");
- let presence = make_presence(caps);
- tx.start_send(Packet::Stanza(presence.into())).unwrap();
- } else if let Some(stanza) = event.into_stanza() {
- if stanza.is("iq", "jabber:client") {
- let iq = Iq::try_from(stanza).unwrap();
- if let IqType::Get(payload) = iq.payload {
- if payload.is("query", ns::DISCO_INFO) {
- let query = DiscoInfoQuery::try_from(payload);
- match query {
- Ok(query) => {
- let mut disco = disco_info.clone();
- disco.node = query.node;
- let iq =
- Iq::from_result(iq.id, Some(disco)).with_to(iq.from.unwrap());
- tx.start_send(Packet::Stanza(iq.into())).unwrap();
+ let mut stream_ended = false;
+ while !stream_ended {
+ if let Some(event) = client.next().await {
+ if wait_for_stream_end {
+ /* Do nothing */
+ } else if event.is_online() {
+ println!("Online!");
+
+ let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp");
+ let presence = make_presence(caps);
+ client.send_stanza(presence.into()).await.unwrap();
+ } else if let Some(stanza) = event.into_stanza() {
+ if stanza.is("iq", "jabber:client") {
+ let iq = Iq::try_from(stanza).unwrap();
+ if let IqType::Get(payload) = iq.payload {
+ if payload.is("query", ns::DISCO_INFO) {
+ let query = DiscoInfoQuery::try_from(payload);
+ match query {
+ Ok(query) => {
+ let mut disco = disco_info.clone();
+ disco.node = query.node;
+ let iq = Iq::from_result(iq.id, Some(disco))
+ .with_to(iq.from.unwrap());
+ client.send_stanza(iq.into()).await.unwrap();
+ }
+ Err(err) => client
+ .send_stanza(make_error(
+ iq.from.unwrap(),
+ iq.id,
+ ErrorType::Modify,
+ DefinedCondition::BadRequest,
+ &format!("{}", err),
+ ))
+ .await
+ .unwrap(),
}
- Err(err) => {
- send_error(
+ } else {
+ // We MUST answer unhandled get iqs with a service-unavailable error.
+ client
+ .send_stanza(make_error(
iq.from.unwrap(),
iq.id,
- ErrorType::Modify,
- DefinedCondition::BadRequest,
- &format!("{}", err),
- );
- }
+ ErrorType::Cancel,
+ DefinedCondition::ServiceUnavailable,
+ "No handler defined for this kind of iq.",
+ ))
+ .await
+ .unwrap();
}
- } else {
- // We MUST answer unhandled get iqs with a service-unavailable error.
- send_error(
- iq.from.unwrap(),
- iq.id,
- ErrorType::Cancel,
- DefinedCondition::ServiceUnavailable,
- "No handler defined for this kind of iq.",
- );
- }
- } else if let IqType::Result(Some(payload)) = iq.payload {
- if payload.is("pubsub", ns::PUBSUB) {
- let pubsub = PubSub::try_from(payload).unwrap();
- let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap());
- handle_iq_result(pubsub, &from);
+ } else if let IqType::Result(Some(payload)) = iq.payload {
+ if payload.is("pubsub", ns::PUBSUB) {
+ let pubsub = PubSub::try_from(payload).unwrap();
+ let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap());
+ handle_iq_result(pubsub, &from);
+ }
+ } else if let IqType::Set(_) = iq.payload {
+ // We MUST answer unhandled set iqs with a service-unavailable error.
+ client
+ .send_stanza(make_error(
+ iq.from.unwrap(),
+ iq.id,
+ ErrorType::Cancel,
+ DefinedCondition::ServiceUnavailable,
+ "No handler defined for this kind of iq.",
+ ))
+ .await
+ .unwrap();
}
- } else if let IqType::Set(_) = iq.payload {
- // We MUST answer unhandled set iqs with a service-unavailable error.
- send_error(
- iq.from.unwrap(),
- iq.id,
- ErrorType::Cancel,
- DefinedCondition::ServiceUnavailable,
- "No handler defined for this kind of iq.",
- );
- }
- } else if stanza.is("message", "jabber:client") {
- let message = Message::try_from(stanza).unwrap();
- let from = message.from.clone().unwrap();
- if let Some(body) = message.get_best_body(vec!["en"]) {
- if body.1 .0 == "die" {
- println!("Secret die command triggered by {}", from);
- wait_for_stream_end = true;
- tx.start_send(Packet::StreamEnd).unwrap();
+ } else if stanza.is("message", "jabber:client") {
+ let message = Message::try_from(stanza).unwrap();
+ let from = message.from.clone().unwrap();
+ if let Some(body) = message.get_best_body(vec!["en"]) {
+ if body.0 == "die" {
+ println!("Secret die command triggered by {}", from);
+ wait_for_stream_end = true;
+ client.send_end().await.unwrap();
+ }
}
- }
- for child in message.payloads {
- if child.is("event", ns::PUBSUB_EVENT) {
- let event = PubSubEvent::try_from(child).unwrap();
- if let PubSubEvent::PublishedItems { node, items } = event {
- if node.0 == ns::AVATAR_METADATA {
- for item in items.into_iter() {
- let payload = item.payload.clone().unwrap();
- if payload.is("metadata", ns::AVATAR_METADATA) {
- // TODO: do something with these metadata.
- let _metadata = AvatarMetadata::try_from(payload).unwrap();
- println!(
- "[1m{}[0m has published an avatar, downloading...",
- from.clone()
- );
- let iq = download_avatar(from.clone());
- tx.start_send(Packet::Stanza(iq.into())).unwrap();
+ for child in message.payloads {
+ if child.is("event", ns::PUBSUB_EVENT) {
+ let event = PubSubEvent::try_from(child).unwrap();
+ if let PubSubEvent::PublishedItems { node, items } = event {
+ if node.0 == ns::AVATAR_METADATA {
+ for item in items.into_iter() {
+ let payload = item.payload.clone().unwrap();
+ if payload.is("metadata", ns::AVATAR_METADATA) {
+ // TODO: do something with these metadata.
+ let _metadata =
+ AvatarMetadata::try_from(payload).unwrap();
+ println!(
+ "[1m{}[0m has published an avatar, downloading...",
+ from.clone()
+ );
+ let iq = download_avatar(from.clone());
+ client.send_stanza(iq.into()).await.unwrap();
+ }
}
}
}
}
}
+ } else if stanza.is("presence", "jabber:client") {
+ // Nothing to do here.
+ ()
+ } else {
+ panic!("Unknown stanza: {}", String::from(&stanza));
}
- } else if stanza.is("presence", "jabber:client") {
- // Nothing to do here.
- } else {
- panic!("Unknown stanza: {}", String::from(&stanza));
}
- }
-
- future::ok(())
- });
-
- // Start polling `done`
- match rt.block_on(done) {
- Ok(_) => (),
- Err(e) => {
- println!("Fatal: {}", e);
- ()
+ } else {
+ println!("stream_ended");
+ stream_ended = true;
}
}
}
+fn make_error(
+ to: Jid,
+ id: String,
+ type_: ErrorType,
+ condition: DefinedCondition,
+ text: &str,
+) -> Element {
+ let error = StanzaError::new(type_, condition, "en", text);
+ let iq = Iq::from_error(id, error).with_to(to);
+ iq.into()
+}
+
fn make_disco() -> DiscoInfoResult {
let identities = vec![Identity::new("client", "bot", "en", "tokio-xmpp")];
let features = vec![
@@ -235,6 +226,7 @@ fn handle_iq_result(pubsub: PubSub, from: &Jid) {
}
}
+// TODO: may use tokio?
fn save_avatar(from: &Jid, id: String, data: &[u8]) -> io::Result<()> {
let directory = format!("data/{}", from);
let filename = format!("data/{}/{}", from, id);
@@ -1,14 +1,15 @@
-use futures::{future, Future, Sink, Stream};
+use futures::stream::StreamExt;
use std::convert::TryFrom;
use std::env::args;
use std::process::exit;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{Client, Packet};
+use tokio;
+use tokio_xmpp::Client;
use xmpp_parsers::message::{Body, Message, MessageType};
use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType};
use xmpp_parsers::{Element, Jid};
-fn main() {
+#[tokio::main]
+async fn main() {
let args: Vec<String> = args().collect();
if args.len() != 3 {
println!("Usage: {} <jid> <password>", args[0]);
@@ -17,72 +18,50 @@ fn main() {
let jid = &args[1];
let password = &args[2];
- // tokio_core context
- let mut rt = Runtime::new().unwrap();
// Client instance
- let client = Client::new(jid, password).unwrap();
-
- // Make the two interfaces for sending and receiving independent
- // of each other so we can move one into a closure.
- let (sink, stream) = client.split();
-
- // Create outgoing pipe
- let (mut tx, rx) = futures::unsync::mpsc::unbounded();
- rt.spawn(
- rx.forward(sink.sink_map_err(|_| panic!("Pipe")))
- .map(|(rx, mut sink)| {
- drop(rx);
- let _ = sink.close();
- })
- .map_err(|e| {
- panic!("Send error: {:?}", e);
- }),
- );
+ let mut client = Client::new(jid, password.to_owned()).unwrap();
+ client.set_reconnect(true);
// Main loop, processes events
let mut wait_for_stream_end = false;
- let done = stream.for_each(move |event| {
- if wait_for_stream_end {
- /* Do nothing */
- } else if event.is_online() {
- let jid = event
- .get_jid()
- .map(|jid| format!("{}", jid))
- .unwrap_or("unknown".to_owned());
- println!("Online at {}", jid);
+ let mut stream_ended = false;
+ while !stream_ended {
+ if let Some(event) = client.next().await {
+ println!("event: {:?}", event);
+ if wait_for_stream_end {
+ /* Do nothing */
+ } else if event.is_online() {
+ let jid = event
+ .get_jid()
+ .map(|jid| format!("{}", jid))
+ .unwrap_or("unknown".to_owned());
+ println!("Online at {}", jid);
- let presence = make_presence();
- tx.start_send(Packet::Stanza(presence)).unwrap();
- } else if let Some(message) = event
- .into_stanza()
- .and_then(|stanza| Message::try_from(stanza).ok())
- {
- match (message.from, message.bodies.get("")) {
- (Some(ref from), Some(ref body)) if body.0 == "die" => {
- println!("Secret die command triggered by {}", from);
- wait_for_stream_end = true;
- tx.start_send(Packet::StreamEnd).unwrap();
- }
- (Some(ref from), Some(ref body)) => {
- if message.type_ != MessageType::Error {
- // This is a message we'll echo
- let reply = make_reply(from.clone(), &body.0);
- tx.start_send(Packet::Stanza(reply)).unwrap();
+ let presence = make_presence();
+ client.send_stanza(presence).await.unwrap();
+ } else if let Some(message) = event
+ .into_stanza()
+ .and_then(|stanza| Message::try_from(stanza).ok())
+ {
+ match (message.from, message.bodies.get("")) {
+ (Some(ref from), Some(ref body)) if body.0 == "die" => {
+ println!("Secret die command triggered by {}", from);
+ wait_for_stream_end = true;
+ client.send_end().await.unwrap();
+ }
+ (Some(ref from), Some(ref body)) => {
+ if message.type_ != MessageType::Error {
+ // This is a message we'll echo
+ let reply = make_reply(from.clone(), &body.0);
+ client.send_stanza(reply).await.unwrap();
+ }
}
+ _ => {}
}
- _ => {}
}
- }
-
- future::ok(())
- });
-
- // Start polling `done`
- match rt.block_on(done) {
- Ok(_) => (),
- Err(e) => {
- println!("Fatal: {}", e);
- ()
+ } else {
+ println!("stream_ended");
+ stream_ended = true;
}
}
}
@@ -1,15 +1,15 @@
-use futures::{future, Sink, Stream};
+use futures::stream::StreamExt;
use std::convert::TryFrom;
use std::env::args;
use std::process::exit;
use std::str::FromStr;
-use tokio::runtime::current_thread::Runtime;
use tokio_xmpp::Component;
use xmpp_parsers::message::{Body, Message, MessageType};
use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType};
use xmpp_parsers::{Element, Jid};
-fn main() {
+#[tokio::main]
+async fn main() {
let args: Vec<String> = args().collect();
if args.len() < 3 || args.len() > 5 {
println!("Usage: {} <jid> <password> [server] [port]", args[0]);
@@ -24,57 +24,38 @@ fn main() {
.unwrap_or("127.0.0.1".to_owned());
let port: u16 = args.get(4).unwrap().parse().unwrap_or(5347u16);
- // tokio_core context
- let mut rt = Runtime::new().unwrap();
// Component instance
println!("{} {} {} {}", jid, password, server, port);
- let component = Component::new(jid, password, server, port).unwrap();
+ let mut component = Component::new(jid, password, server, port).await.unwrap();
// Make the two interfaces for sending and receiving independent
// of each other so we can move one into a closure.
- println!("Got it: {}", component.jid.clone());
- let (mut sink, stream) = component.split();
- // Wrap sink in Option so that we can take() it for the send(self)
- // to consume and return it back when ready.
- let mut send = move |stanza| {
- sink.start_send(stanza).expect("start_send");
- };
- // Main loop, processes events
- let done = stream.for_each(|event| {
- if event.is_online() {
- println!("Online!");
+ println!("Online: {}", component.jid);
+
+ // TODO: replace these hardcoded JIDs
+ let presence = make_presence(
+ Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(),
+ Jid::from_str("linkmauve@linkmauve.fr").unwrap(),
+ );
+ component.send_stanza(presence).await.unwrap();
- // TODO: replace these hardcoded JIDs
- let presence = make_presence(
- Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(),
- Jid::from_str("linkmauve@linkmauve.fr").unwrap(),
- );
- send(presence);
- } else if let Some(message) = event
- .into_stanza()
- .and_then(|stanza| Message::try_from(stanza).ok())
- {
- // This is a message we'll echo
- match (message.from, message.bodies.get("")) {
- (Some(from), Some(body)) => {
- if message.type_ != MessageType::Error {
- let reply = make_reply(from, &body.0);
- send(reply);
+ // Main loop, processes events
+ loop {
+ if let Some(stanza) = component.next().await {
+ if let Some(message) = Message::try_from(stanza).ok() {
+ // This is a message we'll echo
+ match (message.from, message.bodies.get("")) {
+ (Some(from), Some(body)) => {
+ if message.type_ != MessageType::Error {
+ let reply = make_reply(from, &body.0);
+ component.send_stanza(reply).await.unwrap();
+ }
}
+ _ => (),
}
- _ => (),
}
- }
-
- Box::new(future::ok(()))
- });
-
- // Start polling `done`
- match rt.block_on(done) {
- Ok(_) => (),
- Err(e) => {
- println!("Fatal: {}", e);
- ()
+ } else {
+ break;
}
}
}
@@ -1,7 +1,4 @@
-use futures::{
- future::{err, ok, IntoFuture},
- Future, Poll, Stream,
-};
+use futures::stream::StreamExt;
use sasl::client::mechanisms::{Anonymous, Plain, Scram};
use sasl::client::Mechanism;
use sasl::common::scram::{Sha1, Sha256};
@@ -9,7 +6,7 @@ use sasl::common::Credentials;
use std::collections::HashSet;
use std::convert::TryFrom;
use std::str::FromStr;
-use tokio_io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
use crate::xmpp_codec::Packet;
@@ -18,109 +15,70 @@ use crate::{AuthError, Error, ProtocolError};
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
-pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
- future: Box<dyn Future<Item = XMPPStream<S>, Error = Error>>,
-}
-
-impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
- pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
- let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism>>> = vec![
- Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
- Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
- Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
- Box::new(|| Box::new(Anonymous::new())),
- ];
+pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+ mut stream: XMPPStream<S>,
+ creds: Credentials,
+) -> Result<S, Error> {
+ let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
+ Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
+ Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
+ Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
+ Box::new(|| Box::new(Anonymous::new())),
+ ];
- let remote_mechs: HashSet<String> = stream
- .stream_features
- .get_child("mechanisms", NS_XMPP_SASL)
- .ok_or(AuthError::NoMechanism)?
- .children()
- .filter(|child| child.is("mechanism", NS_XMPP_SASL))
- .map(|mech_el| mech_el.text())
- .collect();
+ let remote_mechs: HashSet<String> = stream
+ .stream_features
+ .get_child("mechanisms", NS_XMPP_SASL)
+ .ok_or(AuthError::NoMechanism)?
+ .children()
+ .filter(|child| child.is("mechanism", NS_XMPP_SASL))
+ .map(|mech_el| mech_el.text())
+ .collect();
- for local_mech in local_mechs {
- let mut mechanism = local_mech();
- if remote_mechs.contains(mechanism.name()) {
- let initial = mechanism.initial().map_err(AuthError::Sasl)?;
- let mechanism_name =
- XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
+ for local_mech in local_mechs {
+ let mut mechanism = local_mech();
+ if remote_mechs.contains(mechanism.name()) {
+ let initial = mechanism.initial().map_err(AuthError::Sasl)?;
+ let mechanism_name =
+ XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
- let send_initial = Box::new(stream.send_stanza(Auth {
+ stream
+ .send_stanza(Auth {
mechanism: mechanism_name,
data: initial,
- }))
- .map_err(Error::Io);
- let future = Box::new(
- send_initial
- .and_then(|stream| Self::handle_challenge(stream, mechanism))
- .and_then(|stream| stream.restart()),
- );
- return Ok(ClientAuth { future });
- }
- }
+ })
+ .await?;
- Err(AuthError::NoMechanism)?
- }
+ loop {
+ match stream.next().await {
+ Some(Ok(Packet::Stanza(stanza))) => {
+ if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
+ let response = mechanism
+ .response(&challenge.data)
+ .map_err(|e| AuthError::Sasl(e))?;
- fn handle_challenge(
- stream: XMPPStream<S>,
- mut mechanism: Box<dyn Mechanism>,
- ) -> Box<dyn Future<Item = XMPPStream<S>, Error = Error>> {
- Box::new(
- stream
- .into_future()
- .map_err(|(e, _stream)| e.into())
- .and_then(|(stanza, stream)| {
- match stanza {
- Some(Packet::Stanza(stanza)) => {
- if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
- let response = mechanism.response(&challenge.data);
- Box::new(
- response
- .map_err(|e| AuthError::Sasl(e).into())
- .into_future()
- .and_then(|response| {
- // Send response and loop
- stream
- .send_stanza(Response { data: response })
- .map_err(Error::Io)
- .and_then(|stream| {
- Self::handle_challenge(stream, mechanism)
- })
- }),
- )
- } else if let Ok(_) = Success::try_from(stanza.clone()) {
- Box::new(ok(stream))
- } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
- Box::new(err(Error::Auth(AuthError::Fail(
- failure.defined_condition,
- ))))
- } else if stanza.name() == "failure" {
- // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
- Box::new(err(Error::Auth(AuthError::Sasl("failure".to_string()))))
- } else {
- // ignore and loop
- Self::handle_challenge(stream, mechanism)
- }
- }
- Some(_) => {
+ // Send response and loop
+ stream.send_stanza(Response { data: response }).await?;
+ } else if let Ok(_) = Success::try_from(stanza.clone()) {
+ return Ok(stream.into_inner());
+ } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
+ return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
+ } else if stanza.name() == "failure" {
+ // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
+ return Err(Error::Auth(AuthError::Sasl("failure".to_string())));
+ } else {
// ignore and loop
- Self::handle_challenge(stream, mechanism)
}
- None => Box::new(err(Error::Disconnected)),
}
- }),
- )
+ Some(Ok(_)) => {
+ // ignore and loop
+ }
+ Some(Err(e)) => return Err(e),
+ None => return Err(Error::Disconnected),
+ }
+ }
+ }
}
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
- type Item = XMPPStream<S>;
- type Error = Error;
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- self.future.poll()
- }
+ Err(AuthError::NoMechanism.into())
}
@@ -1,7 +1,7 @@
-use futures::{sink, Async, Future, Poll, Stream};
+use futures::stream::StreamExt;
use std::convert::TryFrom;
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::bind::{BindQuery, BindResponse};
use xmpp_parsers::iq::{Iq, IqType};
use xmpp_parsers::Jid;
@@ -13,90 +13,43 @@ use crate::{Error, ProtocolError};
const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
const BIND_REQ_ID: &str = "resource-bind";
-pub enum ClientBind<S: AsyncWrite> {
- Unsupported(XMPPStream<S>),
- WaitSend(sink::Send<XMPPStream<S>>),
- WaitRecv(XMPPStream<S>),
- Invalid,
-}
-
-impl<S: AsyncWrite> ClientBind<S> {
- /// Consumes and returns the stream to express that you cannot use
- /// the stream for anything else until the resource binding
- /// req/resp are done.
- pub fn new(stream: XMPPStream<S>) -> Self {
- match stream.stream_features.get_child("bind", NS_XMPP_BIND) {
- None =>
+pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
+ mut stream: XMPPStream<S>,
+) -> Result<XMPPStream<S>, Error> {
+ match stream.stream_features.get_child("bind", NS_XMPP_BIND) {
+ None => {
// No resource binding available,
// return the (probably // usable) stream immediately
- {
- ClientBind::Unsupported(stream)
- }
- Some(_) => {
- let resource;
- if let Jid::Full(jid) = stream.jid.clone() {
- resource = Some(jid.resource);
- } else {
- resource = None;
- }
- let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
- let send = stream.send_stanza(iq);
- ClientBind::WaitSend(send)
- }
+ return Ok(stream);
}
- }
-}
+ Some(_) => {
+ let resource = if let Jid::Full(jid) = stream.jid.clone() {
+ Some(jid.resource)
+ } else {
+ None
+ };
+ let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
+ stream.send_stanza(iq).await?;
-impl<S: AsyncRead + AsyncWrite> Future for ClientBind<S> {
- type Item = XMPPStream<S>;
- type Error = Error;
-
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let state = replace(self, ClientBind::Invalid);
-
- match state {
- ClientBind::Unsupported(stream) => Ok(Async::Ready(stream)),
- ClientBind::WaitSend(mut send) => match send.poll() {
- Ok(Async::Ready(stream)) => {
- replace(self, ClientBind::WaitRecv(stream));
- self.poll()
- }
- Ok(Async::NotReady) => {
- replace(self, ClientBind::WaitSend(send));
- Ok(Async::NotReady)
- }
- Err(e) => Err(e)?,
- },
- ClientBind::WaitRecv(mut stream) => match stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => match Iq::try_from(stanza) {
- Ok(iq) => {
- if iq.id == BIND_REQ_ID {
- match iq.payload {
- IqType::Result(payload) => {
- payload
- .and_then(|payload| BindResponse::try_from(payload).ok())
- .map(|bind| stream.jid = bind.into());
- Ok(Async::Ready(stream))
- }
- _ => Err(ProtocolError::InvalidBindResponse)?,
+ loop {
+ match stream.next().await {
+ Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
+ Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
+ IqType::Result(payload) => {
+ payload
+ .and_then(|payload| BindResponse::try_from(payload).ok())
+ .map(|bind| stream.jid = bind.into());
+ return Ok(stream);
}
- } else {
- Ok(Async::NotReady)
- }
- }
- _ => Ok(Async::NotReady),
- },
- Ok(Async::Ready(_)) => {
- replace(self, ClientBind::WaitRecv(stream));
- self.poll()
- }
- Ok(Async::NotReady) => {
- replace(self, ClientBind::WaitRecv(stream));
- Ok(Async::NotReady)
+ _ => return Err(ProtocolError::InvalidBindResponse.into()),
+ },
+ _ => {}
+ },
+ Some(Ok(_)) => {}
+ Some(Err(e)) => return Err(e),
+ None => return Err(Error::Disconnected),
}
- Err(e) => Err(e)?,
- },
- ClientBind::Invalid => unreachable!(),
+ }
}
}
}
@@ -1,28 +1,33 @@
-use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
+use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
use idna;
use sasl::common::{ChannelBinding, Credentials};
use std::mem::replace;
+use std::pin::Pin;
use std::str::FromStr;
+use std::task::Context;
+use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
-use tokio_io::{AsyncRead, AsyncWrite};
+use tokio::task::JoinHandle;
+use tokio::task::LocalSet;
use tokio_tls::TlsStream;
-use xmpp_parsers::{Jid, JidParseError};
+use xmpp_parsers::{Element, Jid, JidParseError};
use super::event::Event;
-use super::happy_eyeballs::Connecter;
-use super::starttls::{StartTlsClient, NS_XMPP_TLS};
+use super::happy_eyeballs::connect;
+use super::starttls::{starttls, NS_XMPP_TLS};
use super::xmpp_codec::Packet;
use super::xmpp_stream;
use super::{Error, ProtocolError};
mod auth;
-use self::auth::ClientAuth;
mod bind;
-use self::bind::ClientBind;
/// XMPP client connection and state
pub struct Client {
state: ClientState,
+ jid: Jid,
+ password: String,
+ reconnect: bool,
}
type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
@@ -31,7 +36,7 @@ const NS_JABBER_CLIENT: &str = "jabber:client";
enum ClientState {
Invalid,
Disconnected,
- Connecting(Box<dyn Future<Item = XMPPStream, Error = Error>>),
+ Connecting(JoinHandle<Result<XMPPStream, Error>>, LocalSet),
Connected(XMPPStream),
}
@@ -40,87 +45,87 @@ impl Client {
///
/// Start polling the returned instance so that it will connect
/// and yield events.
- pub fn new(jid: &str, password: &str) -> Result<Self, JidParseError> {
+ pub fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, JidParseError> {
let jid = Jid::from_str(jid)?;
- let client = Self::new_with_jid(jid, password);
+ let client = Self::new_with_jid(jid, password.into());
Ok(client)
}
/// Start a new client given that the JID is already parsed.
- pub fn new_with_jid(jid: Jid, password: &str) -> Self {
- let password = password.to_owned();
- let connect = Self::make_connect(jid, password.clone());
+ pub fn new_with_jid(jid: Jid, password: String) -> Self {
+ let local = LocalSet::new();
+ let connect = local.spawn_local(Self::connect(jid.clone(), password.clone()));
let client = Client {
- state: ClientState::Connecting(Box::new(connect)),
+ jid,
+ password,
+ state: ClientState::Connecting(connect, local),
+ reconnect: false,
};
client
}
- fn make_connect(jid: Jid, password: String) -> impl Future<Item = XMPPStream, Error = Error> {
+ /// Set whether to reconnect (`true`) or end the stream (`false`)
+ /// when a connection to the server has ended.
+ pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
+ self.reconnect = reconnect;
+ self
+ }
+
+ async fn connect(jid: Jid, password: String) -> Result<XMPPStream, Error> {
let username = jid.clone().node().unwrap();
- let jid1 = jid.clone();
- let jid2 = jid.clone();
let password = password;
- done(idna::domain_to_ascii(&jid.domain()))
- .map_err(|_| Error::Idna)
- .and_then(|domain| {
- done(Connecter::from_lookup(
- &domain,
- Some("_xmpp-client._tcp"),
- 5222,
- ))
- })
- .flatten()
- .and_then(move |tcp_stream| {
- xmpp_stream::XMPPStream::start(tcp_stream, jid1, NS_JABBER_CLIENT.to_owned())
- })
- .and_then(|xmpp_stream| {
- if Self::can_starttls(&xmpp_stream) {
- Ok(Self::starttls(xmpp_stream))
- } else {
- Err(Error::Protocol(ProtocolError::NoTls))
- }
- })
- .flatten()
- .and_then(|tls_stream| XMPPStream::start(tls_stream, jid2, NS_JABBER_CLIENT.to_owned()))
- .and_then(
- move |xmpp_stream| done(Self::auth(xmpp_stream, username, password)), // TODO: flatten?
- )
- .and_then(|auth| auth)
- .and_then(|xmpp_stream| Self::bind(xmpp_stream))
- .and_then(|xmpp_stream| {
- // println!("Bound to {}", xmpp_stream.jid);
- Ok(xmpp_stream)
- })
- }
-
- fn can_starttls<S>(stream: &xmpp_stream::XMPPStream<S>) -> bool {
- stream
+ let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
+
+ let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?;
+
+ let xmpp_stream =
+ xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?;
+ let xmpp_stream = if Self::can_starttls(&xmpp_stream) {
+ Self::starttls(xmpp_stream).await?
+ } else {
+ return Err(Error::Protocol(ProtocolError::NoTls));
+ };
+
+ let xmpp_stream = Self::auth(xmpp_stream, username, password).await?;
+ let xmpp_stream = Self::bind(xmpp_stream).await?;
+ Ok(xmpp_stream)
+ }
+
+ fn can_starttls<S: AsyncRead + AsyncWrite + Unpin>(
+ xmpp_stream: &xmpp_stream::XMPPStream<S>,
+ ) -> bool {
+ xmpp_stream
.stream_features
.get_child("starttls", NS_XMPP_TLS)
.is_some()
}
- fn starttls<S: AsyncRead + AsyncWrite>(
- stream: xmpp_stream::XMPPStream<S>,
- ) -> StartTlsClient<S> {
- StartTlsClient::from_stream(stream)
+ async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+ xmpp_stream: xmpp_stream::XMPPStream<S>,
+ ) -> Result<xmpp_stream::XMPPStream<TlsStream<S>>, Error> {
+ let jid = xmpp_stream.jid.clone();
+ let tls_stream = starttls(xmpp_stream).await?;
+ xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await
}
- fn auth<S: AsyncRead + AsyncWrite + 'static>(
- stream: xmpp_stream::XMPPStream<S>,
+ async fn auth<S: AsyncRead + AsyncWrite + Unpin + 'static>(
+ xmpp_stream: xmpp_stream::XMPPStream<S>,
username: String,
password: String,
- ) -> Result<ClientAuth<S>, Error> {
+ ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
+ let jid = xmpp_stream.jid.clone();
let creds = Credentials::default()
.with_username(username)
.with_password(password)
.with_channel_binding(ChannelBinding::None);
- ClientAuth::new(stream, creds)
+ let stream = auth::auth(xmpp_stream, creds).await?;
+ xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await
}
- fn bind<S: AsyncWrite>(stream: xmpp_stream::XMPPStream<S>) -> ClientBind<S> {
- ClientBind::new(stream)
+ async fn bind<S: Unpin + AsyncRead + AsyncWrite>(
+ stream: xmpp_stream::XMPPStream<S>,
+ ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
+ bind::bind(stream).await
}
/// Get the client's bound JID (the one reported by the XMPP
@@ -131,102 +136,150 @@ impl Client {
_ => None,
}
}
+
+ /// Send stanza
+ pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
+ self.send(Packet::Stanza(stanza)).await
+ }
+
+ /// End connection
+ pub async fn send_end(&mut self) -> Result<(), Error> {
+ self.send(Packet::StreamEnd).await
+ }
}
impl Stream for Client {
type Item = Event;
- type Error = Error;
- fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let state = replace(&mut self.state, ClientState::Invalid);
match state {
- ClientState::Invalid => Err(Error::InvalidState),
- ClientState::Disconnected => Ok(Async::Ready(None)),
- ClientState::Connecting(mut connect) => match connect.poll() {
- Ok(Async::Ready(stream)) => {
- let jid = stream.jid.clone();
- self.state = ClientState::Connected(stream);
- Ok(Async::Ready(Some(Event::Online(jid))))
- }
- Ok(Async::NotReady) => {
- self.state = ClientState::Connecting(connect);
- Ok(Async::NotReady)
+ ClientState::Invalid => panic!("Invalid client state"),
+ ClientState::Disconnected if self.reconnect => {
+ // TODO: add timeout
+ let mut local = LocalSet::new();
+ let connect =
+ local.spawn_local(Self::connect(self.jid.clone(), self.password.clone()));
+ let _ = Pin::new(&mut local).poll(cx);
+ self.state = ClientState::Connecting(connect, local);
+ self.poll_next(cx)
+ }
+ ClientState::Disconnected => Poll::Ready(None),
+ ClientState::Connecting(mut connect, mut local) => {
+ match Pin::new(&mut connect).poll(cx) {
+ Poll::Ready(Ok(Ok(stream))) => {
+ let bound_jid = stream.jid.clone();
+ self.state = ClientState::Connected(stream);
+ Poll::Ready(Some(Event::Online(bound_jid)))
+ }
+ Poll::Ready(Ok(Err(e))) => {
+ self.state = ClientState::Disconnected;
+ return Poll::Ready(Some(Event::Disconnected(e.into())));
+ }
+ Poll::Ready(Err(e)) => {
+ self.state = ClientState::Disconnected;
+ panic!("connect task: {}", e);
+ }
+ Poll::Pending => {
+ let _ = Pin::new(&mut local).poll(cx);
+
+ self.state = ClientState::Connecting(connect, local);
+ Poll::Pending
+ }
}
- Err(e) => Err(e),
- },
+ }
ClientState::Connected(mut stream) => {
// Poll sink
- match stream.poll_complete() {
- Ok(Async::NotReady) => (),
- Ok(Async::Ready(())) => (),
- Err(e) => return Err(e)?,
+ match Pin::new(&mut stream).poll_ready(cx) {
+ Poll::Pending => (),
+ Poll::Ready(Ok(())) => (),
+ Poll::Ready(Err(e)) => {
+ self.state = ClientState::Disconnected;
+ return Poll::Ready(Some(Event::Disconnected(e.into())));
+ }
};
// Poll stream
- match stream.poll() {
- Ok(Async::Ready(None)) => {
+ match Pin::new(&mut stream).poll_next(cx) {
+ Poll::Ready(None) => {
// EOF
self.state = ClientState::Disconnected;
- Ok(Async::Ready(Some(Event::Disconnected)))
+ Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
}
- Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
+ Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
// Receive stanza
self.state = ClientState::Connected(stream);
- Ok(Async::Ready(Some(Event::Stanza(stanza))))
+ Poll::Ready(Some(Event::Stanza(stanza)))
}
- Ok(Async::Ready(Some(Packet::Text(_)))) => {
+ Poll::Ready(Some(Ok(Packet::Text(_)))) => {
// Ignore text between stanzas
self.state = ClientState::Connected(stream);
- Ok(Async::NotReady)
+ Poll::Pending
}
- Ok(Async::Ready(Some(Packet::StreamStart(_)))) => {
+ Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
// <stream:stream>
- Err(ProtocolError::InvalidStreamStart.into())
+ self.state = ClientState::Disconnected;
+ Poll::Ready(Some(Event::Disconnected(
+ ProtocolError::InvalidStreamStart.into(),
+ )))
}
- Ok(Async::Ready(Some(Packet::StreamEnd))) => {
+ Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
// End of stream: </stream:stream>
- Ok(Async::Ready(None))
+ self.state = ClientState::Disconnected;
+ Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
}
- Ok(Async::NotReady) => {
+ Poll::Pending => {
// Try again later
self.state = ClientState::Connected(stream);
- Ok(Async::NotReady)
+ Poll::Pending
+ }
+ Poll::Ready(Some(Err(e))) => {
+ self.state = ClientState::Disconnected;
+ Poll::Ready(Some(Event::Disconnected(e.into())))
}
- Err(e) => Err(e)?,
}
}
}
}
}
-impl Sink for Client {
- type SinkItem = Packet;
- type SinkError = Error;
+impl Sink<Packet> for Client {
+ type Error = Error;
- fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
+ fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
match self.state {
- ClientState::Connected(ref mut stream) => Ok(stream.start_send(item)?),
- _ => Ok(AsyncSink::NotReady(item)),
+ ClientState::Connected(ref mut stream) => {
+ Pin::new(stream).start_send(item).map_err(|e| e.into())
+ }
+ _ => Err(Error::InvalidState),
}
}
- fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.state {
- ClientState::Connected(ref mut stream) => stream.poll_complete().map_err(|e| e.into()),
- _ => Ok(Async::Ready(())),
+ ClientState::Connected(ref mut stream) => {
+ Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
+ }
+ _ => Poll::Pending,
}
}
- /// This closes the inner TCP stream.
- ///
- /// To synchronize your shutdown with the server side, you should
- /// first send `Packet::StreamEnd` and wait for the end of the
- /// incoming stream before closing the connection.
- fn close(&mut self) -> Poll<(), Self::SinkError> {
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
match self.state {
- ClientState::Connected(ref mut stream) => stream.close().map_err(|e| e.into()),
- _ => Ok(Async::Ready(())),
+ ClientState::Connected(ref mut stream) => {
+ Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
+ }
+ _ => Poll::Pending,
+ }
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ match self.state {
+ ClientState::Connected(ref mut stream) => {
+ Pin::new(stream).poll_close(cx).map_err(|e| e.into())
+ }
+ _ => Poll::Pending,
}
}
}
@@ -1,6 +1,6 @@
-use futures::{sink, Async, Future, Poll, Stream};
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::stream::StreamExt;
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
use xmpp_parsers::component::Handshake;
use crate::xmpp_codec::Packet;
@@ -9,81 +9,27 @@ use crate::{AuthError, Error};
const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept";
-pub struct ComponentAuth<S: AsyncWrite> {
- state: ComponentAuthState<S>,
-}
-
-enum ComponentAuthState<S: AsyncWrite> {
- WaitSend(sink::Send<XMPPStream<S>>),
- WaitRecv(XMPPStream<S>),
- Invalid,
-}
-
-impl<S: AsyncWrite> ComponentAuth<S> {
- // TODO: doesn't have to be a Result<> actually
- pub fn new(stream: XMPPStream<S>, password: String) -> Result<Self, Error> {
- // FIXME: huge hack, shouldnβt be an element!
- let sid = stream.stream_features.name().to_owned();
- let mut this = ComponentAuth {
- state: ComponentAuthState::Invalid,
- };
- this.send(
- stream,
- Handshake::from_password_and_stream_id(&password, &sid),
- );
- Ok(this)
- }
-
- fn send(&mut self, stream: XMPPStream<S>, handshake: Handshake) {
- let nonza = handshake;
- let send = stream.send_stanza(nonza);
-
- self.state = ComponentAuthState::WaitSend(send);
- }
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for ComponentAuth<S> {
- type Item = XMPPStream<S>;
- type Error = Error;
-
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let state = replace(&mut self.state, ComponentAuthState::Invalid);
-
- match state {
- ComponentAuthState::WaitSend(mut send) => match send.poll() {
- Ok(Async::Ready(stream)) => {
- self.state = ComponentAuthState::WaitRecv(stream);
- self.poll()
- }
- Ok(Async::NotReady) => {
- self.state = ComponentAuthState::WaitSend(send);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e)?,
- },
- ComponentAuthState::WaitRecv(mut stream) => match stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
- if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) =>
- {
- self.state = ComponentAuthState::Invalid;
- Ok(Async::Ready(stream))
- }
- Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
- if stanza.is("error", "http://etherx.jabber.org/streams") =>
- {
- Err(AuthError::ComponentFail.into())
- }
- Ok(Async::Ready(_event)) => {
- // println!("ComponentAuth ignore {:?}", _event);
- Ok(Async::NotReady)
- }
- Ok(_) => {
- self.state = ComponentAuthState::WaitRecv(stream);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e)?,
- },
- ComponentAuthState::Invalid => unreachable!(),
+pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+ stream: &mut XMPPStream<S>,
+ password: String,
+) -> Result<(), Error> {
+ let nonza = Handshake::from_password_and_stream_id(&password, &stream.id);
+ stream.send_stanza(nonza).await?;
+
+ loop {
+ match stream.next().await {
+ Some(Ok(Packet::Stanza(ref stanza)))
+ if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) =>
+ {
+ return Ok(());
+ }
+ Some(Ok(Packet::Stanza(ref stanza)))
+ if stanza.is("error", "http://etherx.jabber.org/streams") =>
+ {
+ return Err(AuthError::ComponentFail.into());
+ }
+ Some(_) => {}
+ None => return Err(Error::Disconnected),
}
}
}
@@ -1,163 +1,115 @@
//! Components in XMPP are services/gateways that are logged into an
//! XMPP server under a JID consisting of just a domain name. They are
//! allowed to use any user and resource identifiers in their stanzas.
-use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
-use std::mem::replace;
+use futures::{sink::SinkExt, task::Poll, Sink, Stream};
+use std::pin::Pin;
use std::str::FromStr;
+use std::task::Context;
use tokio::net::TcpStream;
-use tokio_io::{AsyncRead, AsyncWrite};
-use xmpp_parsers::{Element, Jid, JidParseError};
+use xmpp_parsers::{Element, Jid};
-use super::event::Event;
-use super::happy_eyeballs::Connecter;
+use super::happy_eyeballs::connect;
use super::xmpp_codec::Packet;
use super::xmpp_stream;
use super::Error;
mod auth;
-use self::auth::ComponentAuth;
/// Component connection to an XMPP server
+///
+/// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element`
+/// (stanzas). Connection handling however is up to the user.
pub struct Component {
/// The component's Jabber-Id
pub jid: Jid,
- state: ComponentState,
+ stream: XMPPStream,
}
type XMPPStream = xmpp_stream::XMPPStream<TcpStream>;
const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept";
-enum ComponentState {
- Invalid,
- Disconnected,
- Connecting(Box<dyn Future<Item = XMPPStream, Error = Error>>),
- Connected(XMPPStream),
-}
-
impl Component {
/// Start a new XMPP component
- ///
- /// Start polling the returned instance so that it will connect
- /// and yield events.
- pub fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, JidParseError> {
+ pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, Error> {
let jid = Jid::from_str(jid)?;
let password = password.to_owned();
- let connect = Self::make_connect(jid.clone(), password, server, port);
- Ok(Component {
- jid,
- state: ComponentState::Connecting(Box::new(connect)),
- })
+ let stream = Self::connect(jid.clone(), password, server, port).await?;
+ Ok(Component { jid, stream })
}
- fn make_connect(
+ async fn connect(
jid: Jid,
password: String,
server: &str,
port: u16,
- ) -> impl Future<Item = XMPPStream, Error = Error> {
- let jid1 = jid.clone();
+ ) -> Result<XMPPStream, Error> {
let password = password;
- done(Connecter::from_lookup(server, None, port))
- .flatten()
- .and_then(move |tcp_stream| {
- xmpp_stream::XMPPStream::start(
- tcp_stream,
- jid1,
- NS_JABBER_COMPONENT_ACCEPT.to_owned(),
- )
- })
- .and_then(move |xmpp_stream| Self::auth(xmpp_stream, password).expect("auth"))
+ let tcp_stream = connect(server, None, port).await?;
+ let mut xmpp_stream =
+ xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_COMPONENT_ACCEPT.to_owned())
+ .await?;
+ auth::auth(&mut xmpp_stream, password).await?;
+ Ok(xmpp_stream)
}
- fn auth<S: AsyncRead + AsyncWrite>(
- stream: xmpp_stream::XMPPStream<S>,
- password: String,
- ) -> Result<ComponentAuth<S>, Error> {
- ComponentAuth::new(stream, password)
+ /// Send stanza
+ pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
+ self.send(stanza).await
+ }
+
+ /// End connection
+ pub async fn send_end(&mut self) -> Result<(), Error> {
+ self.close().await
}
}
impl Stream for Component {
- type Item = Event;
- type Error = Error;
-
- fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
- let state = replace(&mut self.state, ComponentState::Invalid);
+ type Item = Element;
- match state {
- ComponentState::Invalid => Err(Error::InvalidState),
- ComponentState::Disconnected => Ok(Async::Ready(None)),
- ComponentState::Connecting(mut connect) => match connect.poll() {
- Ok(Async::Ready(stream)) => {
- self.state = ComponentState::Connected(stream);
- Ok(Async::Ready(Some(Event::Online(self.jid.clone()))))
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+ loop {
+ match Pin::new(&mut self.stream).poll_next(cx) {
+ Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => return Poll::Ready(Some(stanza)),
+ Poll::Ready(Some(Ok(Packet::Text(_)))) => {
+ // retry
}
- Ok(Async::NotReady) => {
- self.state = ComponentState::Connecting(connect);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e),
- },
- ComponentState::Connected(mut stream) => {
- // Poll sink
- match stream.poll_complete() {
- Ok(Async::NotReady) => (),
- Ok(Async::Ready(())) => (),
- Err(e) => return Err(e)?,
- };
-
- // Poll stream
- match stream.poll() {
- Ok(Async::NotReady) => {
- self.state = ComponentState::Connected(stream);
- Ok(Async::NotReady)
- }
- Ok(Async::Ready(None)) => {
- // EOF
- self.state = ComponentState::Disconnected;
- Ok(Async::Ready(Some(Event::Disconnected)))
- }
- Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
- self.state = ComponentState::Connected(stream);
- Ok(Async::Ready(Some(Event::Stanza(stanza))))
- }
- Ok(Async::Ready(_)) => {
- self.state = ComponentState::Connected(stream);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e)?,
+ Poll::Ready(Some(Ok(_))) =>
+ // unexpected
+ {
+ return Poll::Ready(None)
}
+ Poll::Ready(Some(Err(_))) => return Poll::Ready(None),
+ Poll::Ready(None) => return Poll::Ready(None),
+ Poll::Pending => return Poll::Pending,
}
}
}
}
-impl Sink for Component {
- type SinkItem = Element;
- type SinkError = Error;
+impl Sink<Element> for Component {
+ type Error = Error;
- fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
- match self.state {
- ComponentState::Connected(ref mut stream) => match stream
- .start_send(Packet::Stanza(item))
- {
- Ok(AsyncSink::NotReady(Packet::Stanza(stanza))) => Ok(AsyncSink::NotReady(stanza)),
- Ok(AsyncSink::NotReady(_)) => {
- panic!("Component.start_send with stanza but got something else back")
- }
- Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
- Err(e) => Err(e)?,
- },
- _ => Ok(AsyncSink::NotReady(item)),
- }
+ fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> {
+ Pin::new(&mut self.stream)
+ .start_send(Packet::Stanza(item))
+ .map_err(|e| e.into())
}
- fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
- match &mut self.state {
- &mut ComponentState::Connected(ref mut stream) => {
- stream.poll_complete().map_err(|e| e.into())
- }
- _ => Ok(Async::Ready(())),
- }
+ fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Pin::new(&mut self.stream)
+ .poll_ready(cx)
+ .map_err(|e| e.into())
+ }
+
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Pin::new(&mut self.stream)
+ .poll_flush(cx)
+ .map_err(|e| e.into())
+ }
+
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Pin::new(&mut self.stream)
+ .poll_close(cx)
+ .map_err(|e| e.into())
}
}
@@ -8,7 +8,7 @@ use trust_dns_proto::error::ProtoError;
use trust_dns_resolver::error::ResolveError;
use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition;
-use xmpp_parsers::Error as ParsersError;
+use xmpp_parsers::{Error as ParsersError, JidParseError};
/// Top-level error type
#[derive(Debug)]
@@ -20,6 +20,8 @@ pub enum Error {
/// DNS label conversion error, no details available from module
/// `idna`
Idna,
+ /// Error parsing Jabber-Id
+ JidParse(JidParseError),
/// Protocol-level error
Protocol(ProtocolError),
/// Authentication error
@@ -38,6 +40,7 @@ impl fmt::Display for Error {
Error::Io(e) => write!(fmt, "IO error: {}", e),
Error::Connection(e) => write!(fmt, "connection error: {}", e),
Error::Idna => write!(fmt, "IDNA error"),
+ Error::JidParse(e) => write!(fmt, "jid parse error: {}", e),
Error::Protocol(e) => write!(fmt, "protocol error: {}", e),
Error::Auth(e) => write!(fmt, "authentication error: {}", e),
Error::Tls(e) => write!(fmt, "TLS error: {}", e),
@@ -59,6 +62,12 @@ impl From<ConnecterError> for Error {
}
}
+impl From<JidParseError> for Error {
+ fn from(e: JidParseError) -> Self {
+ Error::JidParse(e)
+ }
+}
+
impl From<ProtocolError> for Error {
fn from(e: ProtocolError) -> Self {
Error::Protocol(e)
@@ -1,3 +1,4 @@
+use super::Error;
use xmpp_parsers::{Element, Jid};
/// High-level event on the Stream implemented by Client and Component
@@ -6,7 +7,7 @@ pub enum Event {
/// Stream is connected and initialized
Online(Jid),
/// Stream end
- Disconnected,
+ Disconnected(Error),
/// Received stanza/nonza
Stanza(Element),
}
@@ -1,195 +1,63 @@
use crate::{ConnecterError, Error};
-use futures::{Async, Future, Poll};
-use std::cell::RefCell;
-use std::collections::BTreeMap;
-use std::collections::VecDeque;
-use std::io::Error as IoError;
-use std::mem;
use std::net::SocketAddr;
-use tokio::net::tcp::ConnectFuture;
use tokio::net::TcpStream;
-use trust_dns_resolver::config::LookupIpStrategy;
-use trust_dns_resolver::lookup::SrvLookupFuture;
-use trust_dns_resolver::lookup_ip::LookupIpFuture;
-use trust_dns_resolver::{AsyncResolver, Background, BackgroundLookup, IntoName, Name};
-
-enum State {
- ResolveSrv(AsyncResolver, BackgroundLookup<SrvLookupFuture>),
- ResolveTarget(AsyncResolver, Background<LookupIpFuture>, u16),
- Connecting(Option<AsyncResolver>, Vec<RefCell<ConnectFuture>>),
- Invalid,
+use trust_dns_resolver::{IntoName, TokioAsyncResolver};
+
+async fn connect_to_host(
+ resolver: &TokioAsyncResolver,
+ host: &str,
+ port: u16,
+) -> Result<TcpStream, Error> {
+ let ips = resolver
+ .lookup_ip(host)
+ .await
+ .map_err(ConnecterError::Resolve)?;
+ for ip in ips.iter() {
+ match TcpStream::connect(&SocketAddr::new(ip, port)).await {
+ Ok(stream) => return Ok(stream),
+ Err(_) => {}
+ }
+ }
+ Err(Error::Disconnected)
}
-pub struct Connecter {
+pub async fn connect(
+ domain: &str,
+ srv: Option<&str>,
fallback_port: u16,
- srv_domain: Option<Name>,
- domain: Name,
- state: State,
- targets: VecDeque<(Name, u16)>,
- error: Option<Error>,
-}
-
-fn resolver() -> Result<AsyncResolver, IoError> {
- let (config, mut opts) = trust_dns_resolver::system_conf::read_system_conf()?;
- opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
- let (resolver, resolver_background) = AsyncResolver::new(config, opts);
- tokio::runtime::current_thread::spawn(resolver_background);
- Ok(resolver)
-}
-
-impl Connecter {
- pub fn from_lookup(
- domain: &str,
- srv: Option<&str>,
- fallback_port: u16,
- ) -> Result<Connecter, Error> {
- if let Ok(ip) = domain.parse() {
- // use specified IP address, not domain name, skip the whole dns part
- let connect = RefCell::new(TcpStream::connect(&SocketAddr::new(ip, fallback_port)));
- return Ok(Connecter {
- fallback_port,
- srv_domain: None,
- domain: "nohost".into_name().map_err(ConnecterError::Dns)?,
- state: State::Connecting(None, vec![connect]),
- targets: VecDeque::new(),
- error: None,
- });
- }
-
- let srv_domain = match srv {
- Some(srv) => Some(
- format!("{}.{}.", srv, domain)
- .into_name()
- .map_err(ConnecterError::Dns)?,
- ),
- None => None,
- };
-
- let mut self_ = Connecter {
- fallback_port,
- srv_domain,
- domain: domain.into_name().map_err(ConnecterError::Dns)?,
- state: State::Invalid,
- targets: VecDeque::new(),
- error: None,
- };
-
- let resolver = resolver()?;
- // Initialize state
- match &self_.srv_domain {
- &Some(ref srv_domain) => {
- let srv_lookup = resolver.lookup_srv(srv_domain.clone());
- self_.state = State::ResolveSrv(resolver, srv_lookup);
- }
- None => {
- self_.targets = [(self_.domain.clone(), self_.fallback_port)]
- .iter()
- .cloned()
- .collect();
- self_.state = State::Connecting(Some(resolver), vec![]);
- }
- }
-
- Ok(self_)
+) -> Result<TcpStream, Error> {
+ if let Ok(ip) = domain.parse() {
+ return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
}
-}
-impl Future for Connecter {
- type Item = TcpStream;
- type Error = Error;
+ let resolver = TokioAsyncResolver::tokio_from_system_conf()
+ .await
+ .map_err(ConnecterError::Resolve)?;
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let state = mem::replace(&mut self.state, State::Invalid);
- match state {
- State::ResolveSrv(resolver, mut srv_lookup) => {
- match srv_lookup.poll() {
- Ok(Async::NotReady) => {
- self.state = State::ResolveSrv(resolver, srv_lookup);
- Ok(Async::NotReady)
- }
- Ok(Async::Ready(srv_result)) => {
- let srv_map: BTreeMap<_, _> = srv_result
- .iter()
- .map(|srv| (srv.priority(), (srv.target().clone(), srv.port())))
- .collect();
- let targets = srv_map.into_iter().map(|(_, tp)| tp).collect();
- self.targets = targets;
- self.state = State::Connecting(Some(resolver), vec![]);
- self.poll()
- }
- Err(_) => {
- // ignore, fallback
- self.targets = [(self.domain.clone(), self.fallback_port)]
- .iter()
- .cloned()
- .collect();
- self.state = State::Connecting(Some(resolver), vec![]);
- self.poll()
- }
- }
- }
- State::Connecting(resolver, mut connects) => {
- if resolver.is_some() && connects.len() == 0 && self.targets.len() > 0 {
- let resolver = resolver.unwrap();
- let (host, port) = self.targets.pop_front().unwrap();
- let ip_lookup = resolver.lookup_ip(host);
- self.state = State::ResolveTarget(resolver, ip_lookup, port);
- self.poll()
- } else if connects.len() > 0 {
- let mut success = None;
- connects.retain(|connect| match connect.borrow_mut().poll() {
- Ok(Async::NotReady) => true,
- Ok(Async::Ready(connection)) => {
- success = Some(connection);
- false
- }
- Err(e) => {
- if self.error.is_none() {
- self.error = Some(e.into());
- }
- false
- }
- });
- match success {
- Some(connection) => Ok(Async::Ready(connection)),
- None => {
- self.state = State::Connecting(resolver, connects);
- Ok(Async::NotReady)
- }
- }
- } else {
- // All targets tried
- match self.error.take() {
- None => Err(ConnecterError::AllFailed.into()),
- Some(e) => Err(e),
- }
- }
- }
- State::ResolveTarget(resolver, mut ip_lookup, port) => {
- match ip_lookup.poll() {
- Ok(Async::NotReady) => {
- self.state = State::ResolveTarget(resolver, ip_lookup, port);
- Ok(Async::NotReady)
- }
- Ok(Async::Ready(ip_result)) => {
- let connects = ip_result
- .iter()
- .map(|ip| RefCell::new(TcpStream::connect(&SocketAddr::new(ip, port))))
- .collect();
- self.state = State::Connecting(Some(resolver), connects);
- self.poll()
- }
- Err(e) => {
- if self.error.is_none() {
- self.error = Some(ConnecterError::Resolve(e).into());
- }
- // ignore, nextβ¦
- self.state = State::Connecting(Some(resolver), vec![]);
- self.poll()
- }
+ let srv_records = match srv {
+ Some(srv) => {
+ let srv_domain = format!("{}.{}.", srv, domain)
+ .into_name()
+ .map_err(ConnecterError::Dns)?;
+ resolver.srv_lookup(srv_domain).await.ok()
+ }
+ None => None,
+ };
+
+ match srv_records {
+ Some(lookup) => {
+ // TODO: sort lookup records by priority/weight
+ for srv in lookup.iter() {
+ match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await {
+ Ok(stream) => return Ok(stream),
+ Err(_) => {}
}
}
- _ => panic!(""),
+ Err(Error::Disconnected)
+ }
+ None => {
+ // SRV lookup error, retry with hostname
+ connect_to_host(&resolver, domain, fallback_port).await
}
}
}
@@ -6,10 +6,9 @@ mod starttls;
mod stream_start;
pub mod xmpp_codec;
pub use crate::xmpp_codec::Packet;
-pub mod xmpp_stream;
-pub use crate::starttls::StartTlsClient;
mod event;
mod happy_eyeballs;
+pub mod xmpp_stream;
pub use crate::event::Event;
mod client;
pub use crate::client::Client;
@@ -1,114 +1,39 @@
-use futures::sink;
-use futures::stream::Stream;
-use futures::{Async, Future, Poll, Sink};
+use futures::{sink::SinkExt, stream::StreamExt};
use native_tls::TlsConnector as NativeTlsConnector;
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_tls::{Connect, TlsConnector, TlsStream};
-use xmpp_parsers::{Element, Jid};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_tls::{TlsConnector, TlsStream};
+use xmpp_parsers::Element;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
-use crate::Error;
+use crate::{Error, ProtocolError};
/// XMPP TLS XML namespace
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
-/// XMPP stream that switches to TLS if available in received features
-pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
- state: StartTlsClientState<S>,
- jid: Jid,
-}
-
-enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
- Invalid,
- SendStartTls(sink::Send<XMPPStream<S>>),
- AwaitProceed(XMPPStream<S>),
- StartingTls(Connect<S>),
-}
-
-impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
- /// Waits for <stream:features>
- pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
- let jid = xmpp_stream.jid.clone();
-
- let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
- let packet = Packet::Stanza(nonza);
- let send = xmpp_stream.send(packet);
-
- StartTlsClient {
- state: StartTlsClientState::SendStartTls(send),
- jid,
+pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+ mut xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+ let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
+ let packet = Packet::Stanza(nonza);
+ xmpp_stream.send(packet).await?;
+
+ loop {
+ match xmpp_stream.next().await {
+ Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
+ Some(Ok(Packet::Text(_))) => {}
+ Some(Err(e)) => return Err(e.into()),
+ _ => {
+ return Err(ProtocolError::NoTls.into());
+ }
}
}
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
- type Item = TlsStream<S>;
- type Error = Error;
-
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
- let mut retry = false;
- let (new_state, result) = match old_state {
- StartTlsClientState::SendStartTls(mut send) => match send.poll() {
- Ok(Async::Ready(xmpp_stream)) => {
- let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
- retry = true;
- (new_state, Ok(Async::NotReady))
- }
- Ok(Async::NotReady) => {
- (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
- }
- Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
- },
- StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
- if stanza.name() == "proceed" =>
- {
- let stream = xmpp_stream.stream.into_inner();
- let connect =
- TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
- .connect(&self.jid.clone().domain(), stream);
- let new_state = StartTlsClientState::StartingTls(connect);
- retry = true;
- (new_state, Ok(Async::NotReady))
- }
- Ok(Async::Ready(_value)) => {
- // println!("StartTlsClient ignore {:?}", _value);
- (
- StartTlsClientState::AwaitProceed(xmpp_stream),
- Ok(Async::NotReady),
- )
- }
- Ok(_) => (
- StartTlsClientState::AwaitProceed(xmpp_stream),
- Ok(Async::NotReady),
- ),
- Err(e) => (
- StartTlsClientState::AwaitProceed(xmpp_stream),
- Err(Error::Protocol(e.into())),
- ),
- },
- StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
- Ok(Async::Ready(tls_stream)) => {
- (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
- }
- Ok(Async::NotReady) => (
- StartTlsClientState::StartingTls(connect),
- Ok(Async::NotReady),
- ),
- Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
- },
- StartTlsClientState::Invalid => unreachable!(),
- };
+ let domain = xmpp_stream.jid.clone().domain();
+ let stream = xmpp_stream.into_inner();
+ let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
+ .connect(&domain, stream)
+ .await?;
- self.state = new_state;
- if retry {
- self.poll()
- } else {
- result
- }
- }
+ Ok(tls_stream)
}
@@ -1,7 +1,7 @@
-use futures::{sink, Async, Future, Poll, Sink, Stream};
-use std::mem::replace;
-use tokio_codec::Framed;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::{sink::SinkExt, stream::StreamExt};
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::codec::Framed;
use xmpp_parsers::{Element, Jid};
use crate::xmpp_codec::{Packet, XMPPCodec};
@@ -10,116 +10,66 @@ use crate::{Error, ProtocolError};
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
-pub struct StreamStart<S: AsyncWrite> {
- state: StreamStartState<S>,
+pub async fn start<S: AsyncRead + AsyncWrite + Unpin>(
+ mut stream: Framed<S, XMPPCodec>,
jid: Jid,
ns: String,
-}
-
-enum StreamStartState<S: AsyncWrite> {
- SendStart(sink::Send<Framed<S, XMPPCodec>>),
- RecvStart(Framed<S, XMPPCodec>),
- RecvFeatures(Framed<S, XMPPCodec>, String),
- Invalid,
-}
-
-impl<S: AsyncWrite> StreamStart<S> {
- pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
- let attrs = [
- ("to".to_owned(), jid.clone().domain()),
- ("version".to_owned(), "1.0".to_owned()),
- ("xmlns".to_owned(), ns.clone()),
- ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
- ]
- .iter()
- .cloned()
- .collect();
- let send = stream.send(Packet::StreamStart(attrs));
+) -> Result<XMPPStream<S>, Error> {
+ let attrs = [
+ ("to".to_owned(), jid.clone().domain()),
+ ("version".to_owned(), "1.0".to_owned()),
+ ("xmlns".to_owned(), ns.clone()),
+ ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
+ ]
+ .iter()
+ .cloned()
+ .collect();
+ stream.send(Packet::StreamStart(attrs)).await?;
- StreamStart {
- state: StreamStartState::SendStart(send),
- jid,
- ns,
+ let stream_attrs;
+ loop {
+ match stream.next().await {
+ Some(Ok(Packet::StreamStart(attrs))) => {
+ stream_attrs = attrs;
+ break;
+ }
+ Some(Ok(_)) => {}
+ Some(Err(e)) => return Err(e.into()),
+ None => return Err(Error::Disconnected),
}
}
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
- type Item = XMPPStream<S>;
- type Error = Error;
- fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let old_state = replace(&mut self.state, StreamStartState::Invalid);
- let mut retry = false;
-
- let (new_state, result) = match old_state {
- StreamStartState::SendStart(mut send) => match send.poll() {
- Ok(Async::Ready(stream)) => {
- retry = true;
- (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
- }
- Ok(Async::NotReady) => (StreamStartState::SendStart(send), Ok(Async::NotReady)),
- Err(e) => (StreamStartState::Invalid, Err(e.into())),
- },
- StreamStartState::RecvStart(mut stream) => match stream.poll() {
- Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
- let stream_ns = stream_attrs
- .get("xmlns")
- .ok_or(ProtocolError::NoStreamNamespace)?
- .clone();
- if self.ns == "jabber:client" {
- retry = true;
- // TODO: skip RecvFeatures for version < 1.0
- (
- StreamStartState::RecvFeatures(stream, stream_ns),
- Ok(Async::NotReady),
- )
- } else {
- let id = stream_attrs
- .get("id")
- .ok_or(ProtocolError::NoStreamId)?
- .clone();
- // FIXME: huge hack, shouldnβt be an element!
- let stream = XMPPStream::new(
- self.jid.clone(),
- stream,
- self.ns.clone(),
- Element::builder(id).build(),
- );
- (StreamStartState::Invalid, Ok(Async::Ready(stream)))
- }
+ let stream_ns = stream_attrs
+ .get("xmlns")
+ .ok_or(ProtocolError::NoStreamNamespace)?
+ .clone();
+ let stream_id = stream_attrs
+ .get("id")
+ .ok_or(ProtocolError::NoStreamId)?
+ .clone();
+ let stream = if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() {
+ let stream_features;
+ loop {
+ match stream.next().await {
+ Some(Ok(Packet::Stanza(stanza))) if stanza.is("features", NS_XMPP_STREAM) => {
+ stream_features = stanza;
+ break;
}
- Ok(Async::Ready(_)) => return Err(ProtocolError::InvalidToken.into()),
- Ok(Async::NotReady) => (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
- Err(e) => return Err(ProtocolError::from(e).into()),
- },
- StreamStartState::RecvFeatures(mut stream, stream_ns) => match stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
- if stanza.is("features", NS_XMPP_STREAM) {
- let stream =
- XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
- (StreamStartState::Invalid, Ok(Async::Ready(stream)))
- } else {
- (
- StreamStartState::RecvFeatures(stream, stream_ns),
- Ok(Async::NotReady),
- )
- }
- }
- Ok(Async::Ready(_)) | Ok(Async::NotReady) => (
- StreamStartState::RecvFeatures(stream, stream_ns),
- Ok(Async::NotReady),
- ),
- Err(e) => return Err(ProtocolError::from(e).into()),
- },
- StreamStartState::Invalid => unreachable!(),
- };
-
- self.state = new_state;
- if retry {
- self.poll()
- } else {
- result
+ Some(Ok(_)) => {}
+ Some(Err(e)) => return Err(e.into()),
+ None => return Err(Error::Disconnected),
+ }
}
- }
+ XMPPStream::new(jid, stream, ns, stream_id, stream_features)
+ } else {
+ // FIXME: huge hack, shouldnβt be an element!
+ XMPPStream::new(
+ jid,
+ stream,
+ ns,
+ stream_id.clone(),
+ Element::builder(stream_id).build(),
+ )
+ };
+ Ok(stream)
}
@@ -5,16 +5,16 @@ use bytes::{BufMut, BytesMut};
use log::debug;
use std;
use std::borrow::Cow;
-use std::cell::RefCell;
use std::collections::vec_deque::VecDeque;
use std::collections::HashMap;
use std::default::Default;
use std::fmt::Write;
use std::io;
use std::iter::FromIterator;
-use std::rc::Rc;
use std::str::from_utf8;
-use tokio_codec::{Decoder, Encoder};
+use std::sync::Arc;
+use std::sync::Mutex;
+use tokio_util::codec::{Decoder, Encoder};
use xml5ever::buffer_queue::BufferQueue;
use xml5ever::interface::Attribute;
use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
@@ -38,14 +38,14 @@ type QueueItem = Result<Packet, ParserError>;
/// Parser state
struct ParserSink {
// Ready stanzas, shared with XMPPCodec
- queue: Rc<RefCell<VecDeque<QueueItem>>>,
+ queue: Arc<Mutex<VecDeque<QueueItem>>>,
// Parsing stack
stack: Vec<Element>,
ns_stack: Vec<HashMap<Option<String>, String>>,
}
impl ParserSink {
- pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
+ pub fn new(queue: Arc<Mutex<VecDeque<QueueItem>>>) -> Self {
ParserSink {
queue,
stack: vec![],
@@ -54,11 +54,11 @@ impl ParserSink {
}
fn push_queue(&self, pkt: Packet) {
- self.queue.borrow_mut().push_back(Ok(pkt));
+ self.queue.lock().unwrap().push_back(Ok(pkt));
}
fn push_queue_error(&self, e: ParserError) {
- self.queue.borrow_mut().push_back(Err(e));
+ self.queue.lock().unwrap().push_back(Err(e));
}
/// Lookup XML namespace declaration for given prefix (or no prefix)
@@ -169,7 +169,6 @@ impl TokenSink for ParserSink {
},
Token::EOFToken => self.push_queue(Packet::StreamEnd),
Token::ParseError(s) => {
- // println!("ParseError: {:?}", s);
self.push_queue_error(ParserError::Parse(ParseError(s)));
}
_ => (),
@@ -190,13 +189,13 @@ pub struct XMPPCodec {
// TODO: optimize using tendrils?
buf: Vec<u8>,
/// Shared with ParserSink
- queue: Rc<RefCell<VecDeque<QueueItem>>>,
+ queue: Arc<Mutex<VecDeque<QueueItem>>>,
}
impl XMPPCodec {
/// Constructor
pub fn new() -> Self {
- let queue = Rc::new(RefCell::new(VecDeque::new()));
+ let queue = Arc::new(Mutex::new(VecDeque::new()));
let sink = ParserSink::new(queue.clone());
// TODO: configure parser?
let parser = XmlTokenizer::new(sink, Default::default());
@@ -222,10 +221,10 @@ impl Decoder for XMPPCodec {
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
let mut prefix = std::mem::replace(&mut self.buf, vec![]);
- prefix.extend_from_slice(buf.take().as_ref());
+ prefix.extend_from_slice(&buf.split_to(buf.len()));
Box::new(prefix)
} else {
- Box::new(buf.take())
+ Box::new(buf.split_to(buf.len()))
};
let buf1 = buf1.as_ref().as_ref();
match from_utf8(buf1) {
@@ -258,7 +257,7 @@ impl Decoder for XMPPCodec {
}
}
- match self.queue.borrow_mut().pop_front() {
+ match self.queue.lock().unwrap().pop_front() {
None => Ok(None),
Some(result) => result.map(|pkt| Some(pkt)),
}
@@ -372,7 +371,7 @@ mod tests {
fn test_stream_start() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
@@ -384,14 +383,14 @@ mod tests {
fn test_stream_end() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
_ => false,
});
b.clear();
- b.put(r"</stream:stream>");
+ b.put_slice(b"</stream:stream>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamEnd)) => true,
@@ -403,7 +402,7 @@ mod tests {
fn test_truncated_stanza() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
@@ -411,7 +410,7 @@ mod tests {
});
b.clear();
- b.put(r"<test>Γ</test");
+ b.put_slice("<test>Γ</test".as_bytes());
let r = c.decode(&mut b);
assert!(match r {
Ok(None) => true,
@@ -419,7 +418,7 @@ mod tests {
});
b.clear();
- b.put(r">");
+ b.put_slice(b">");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "Γ" => true,
@@ -431,7 +430,7 @@ mod tests {
fn test_truncated_utf8() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
@@ -460,7 +459,7 @@ mod tests {
fn test_atrribute_prefix() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
@@ -468,7 +467,7 @@ mod tests {
});
b.clear();
- b.put(r"<status xml:lang='en'>Test status</status>");
+ b.put_slice(b"<status xml:lang='en'>Test status</status>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::Stanza(ref el)))
@@ -483,10 +482,10 @@ mod tests {
/// By default, encode() only get's a BytesMut that has 8kb space reserved.
#[test]
fn test_large_stanza() {
- use futures::{Future, Sink};
+ use futures::{executor::block_on, sink::SinkExt};
use std::io::Cursor;
- use tokio_codec::FramedWrite;
- let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
+ use tokio_util::codec::FramedWrite;
+ let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
let mut text = "".to_owned();
for _ in 0..2usize.pow(15) {
text = text + "A";
@@ -494,7 +493,7 @@ mod tests {
let stanza = Element::builder("message")
.append(Element::builder("body").append(text.as_ref()).build())
.build();
- let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
+ block_on(framed.send(Packet::Stanza(stanza))).expect("send");
assert_eq!(
framed.get_ref().get_ref(),
&("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
@@ -505,7 +504,7 @@ mod tests {
fn test_cut_out_stanza() {
let mut c = XMPPCodec::new();
let mut b = BytesMut::with_capacity(1024);
- b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::StreamStart(_))) => true,
@@ -513,8 +512,8 @@ mod tests {
});
b.clear();
- b.put(r"<message ");
- b.put(r"type='chat'><body>Foo</body></message>");
+ b.put_slice(b"<message ");
+ b.put_slice(b"type='chat'><body>Foo</body></message>");
let r = c.decode(&mut b);
assert!(match r {
Ok(Some(Packet::Stanza(_))) => true,
@@ -1,23 +1,28 @@
//! `XMPPStream` is the common container for all XMPP network connections
use futures::sink::Send;
-use futures::{Poll, Sink, StartSend, Stream};
-use tokio_codec::Framed;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::{sink::SinkExt, task::Poll, Sink, Stream};
+use std::ops::DerefMut;
+use std::pin::Pin;
+use std::sync::Mutex;
+use std::task::Context;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::codec::Framed;
use xmpp_parsers::{Element, Jid};
-use crate::stream_start::StreamStart;
+use crate::stream_start;
use crate::xmpp_codec::{Packet, XMPPCodec};
+use crate::Error;
/// <stream:stream> namespace
pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
/// Wraps a `stream`
-pub struct XMPPStream<S> {
+pub struct XMPPStream<S: AsyncRead + AsyncWrite + Unpin> {
/// The local Jabber-Id
pub jid: Jid,
/// Codec instance
- pub stream: Framed<S, XMPPCodec>,
+ pub stream: Mutex<Framed<S, XMPPCodec>>,
/// `<stream:features/>` for XMPP version 1.0
pub stream_features: Element,
/// Root namespace
@@ -25,68 +30,94 @@ pub struct XMPPStream<S> {
/// This is different for either c2s, s2s, or component
/// connections.
pub ns: String,
+ /// Stream `id` attribute
+ pub id: String,
}
-impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
+// // TODO: fix this hack
+// unsafe impl<S: AsyncRead + AsyncWrite + Unpin> core::marker::Send for XMPPStream<S> {}
+// unsafe impl<S: AsyncRead + AsyncWrite + Unpin> Sync for XMPPStream<S> {}
+
+impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
/// Constructor
pub fn new(
jid: Jid,
stream: Framed<S, XMPPCodec>,
ns: String,
+ id: String,
stream_features: Element,
) -> Self {
XMPPStream {
jid,
- stream,
+ stream: Mutex::new(stream),
stream_features,
ns,
+ id,
}
}
/// Send a `<stream:stream>` start tag
- pub fn start(stream: S, jid: Jid, ns: String) -> StreamStart<S> {
+ pub async fn start<'a>(stream: S, jid: Jid, ns: String) -> Result<Self, Error> {
let xmpp_stream = Framed::new(stream, XMPPCodec::new());
- StreamStart::from_stream(xmpp_stream, jid, ns)
+ stream_start::start(xmpp_stream, jid, ns).await
}
/// Unwraps the inner stream
+ // TODO: use this everywhere
pub fn into_inner(self) -> S {
- self.stream.into_inner()
+ self.stream.into_inner().unwrap().into_inner()
}
/// Re-run `start()`
- pub fn restart(self) -> StreamStart<S> {
- Self::start(self.stream.into_inner(), self.jid, self.ns)
+ pub async fn restart<'a>(self) -> Result<Self, Error> {
+ let stream = self.stream.into_inner().unwrap().into_inner();
+ Self::start(stream, self.jid, self.ns).await
}
}
-impl<S: AsyncWrite> XMPPStream<S> {
+impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
/// Convenience method
- pub fn send_stanza<E: Into<Element>>(self, e: E) -> Send<Self> {
+ pub fn send_stanza<E: Into<Element>>(&mut self, e: E) -> Send<Self, Packet> {
self.send(Packet::Stanza(e.into()))
}
}
/// Proxy to self.stream
-impl<S: AsyncWrite> Sink for XMPPStream<S> {
- type SinkItem = <Framed<S, XMPPCodec> as Sink>::SinkItem;
- type SinkError = <Framed<S, XMPPCodec> as Sink>::SinkError;
+impl<S: AsyncRead + AsyncWrite + Unpin> Sink<Packet> for XMPPStream<S> {
+ type Error = crate::Error;
+
+ fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ // Pin::new(&mut self.stream).poll_ready(ctx)
+ // .map_err(|e| e.into())
+ Poll::Ready(Ok(()))
+ }
+
+ fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
+ Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+ .start_send(item)
+ .map_err(|e| e.into())
+ }
- fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
- self.stream.start_send(item)
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+ .poll_flush(cx)
+ .map_err(|e| e.into())
}
- fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
- self.stream.poll_complete()
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+ Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+ .poll_close(cx)
+ .map_err(|e| e.into())
}
}
/// Proxy to self.stream
-impl<S: AsyncRead> Stream for XMPPStream<S> {
- type Item = <Framed<S, XMPPCodec> as Stream>::Item;
- type Error = <Framed<S, XMPPCodec> as Stream>::Error;
+impl<S: AsyncRead + AsyncWrite + Unpin> Stream for XMPPStream<S> {
+ type Item = Result<Packet, crate::Error>;
- fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
- self.stream.poll()
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+ Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+ .poll_next(cx)
+ .map(|result| result.map(|result| result.map_err(|e| e.into())))
}
}