tokio-xmpp: rewrite for futures-0.3

Astro created

Change summary

tokio-xmpp/Cargo.toml                   |  17 
tokio-xmpp/examples/contact_addr.rs     |  80 +++----
tokio-xmpp/examples/download_avatars.rs | 246 +++++++++++------------
tokio-xmpp/examples/echo_bot.rs         | 103 +++------
tokio-xmpp/examples/echo_component.rs   |  71 ++----
tokio-xmpp/src/client/auth.rs           | 154 +++++---------
tokio-xmpp/src/client/bind.rs           | 115 +++-------
tokio-xmpp/src/client/mod.rs            | 277 ++++++++++++++++----------
tokio-xmpp/src/component/auth.rs        | 102 ++-------
tokio-xmpp/src/component/mod.rs         | 176 ++++++----------
tokio-xmpp/src/error.rs                 |  11 
tokio-xmpp/src/event.rs                 |   3 
tokio-xmpp/src/happy_eyeballs.rs        | 230 ++++-----------------
tokio-xmpp/src/lib.rs                   |   3 
tokio-xmpp/src/starttls.rs              | 127 ++---------
tokio-xmpp/src/stream_start.rs          | 166 +++++----------
tokio-xmpp/src/xmpp_codec.rs            |  57 ++--
tokio-xmpp/src/xmpp_stream.rs           |  85 +++++--
18 files changed, 801 insertions(+), 1,222 deletions(-)

Detailed changes

tokio-xmpp/Cargo.toml πŸ”—

@@ -1,6 +1,6 @@
 [package]
 name = "tokio-xmpp"
-version = "1.0.1"
+version = "2.0.0"
 authors = ["Astro <astro@spaceboyz.net>", "Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>", "pep <pep+code@bouah.net>", "O01eg <o01eg@yandex.ru>"]
 description = "Asynchronous XMPP for Rust with tokio"
 license = "MPL-2.0"
@@ -12,17 +12,16 @@ keywords = ["xmpp", "tokio"]
 edition = "2018"
 
 [dependencies]
-bytes = "0.4"
-futures = "0.1"
+bytes = "0.5"
+futures = "0.3"
 idna = "0.2"
 log = "0.4"
 native-tls = "0.2"
 sasl = "0.4"
-tokio = "0.1"
-tokio-codec = "0.1"
-trust-dns-resolver = "0.12"
-trust-dns-proto = "0.8"
-tokio-io = "0.1"
-tokio-tls = "0.2"
+tokio = { version = "0.2", features = ["net", "stream", "rt-util", "rt-threaded", "macros"] }
+tokio-util = { version = "0.2", features = ["codec"] }
+tokio-tls = "0.3"
+trust-dns-resolver = "0.19"
+trust-dns-proto = "0.19"
 xml5ever = "0.16"
 xmpp-parsers = "0.17"

tokio-xmpp/examples/contact_addr.rs πŸ”—

@@ -1,9 +1,8 @@
-use futures::{future, Sink, Stream};
+use futures::stream::StreamExt;
 use std::convert::TryFrom;
 use std::env::args;
 use std::process::exit;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{xmpp_codec::Packet, Client};
+use tokio_xmpp::Client;
 use xmpp_parsers::{
     disco::{DiscoInfoQuery, DiscoInfoResult},
     iq::{Iq, IqType},
@@ -12,70 +11,55 @@ use xmpp_parsers::{
     Element, Jid,
 };
 
-fn main() {
+#[tokio::main]
+async fn main() {
     let args: Vec<String> = args().collect();
     if args.len() != 4 {
         println!("Usage: {} <jid> <password> <target>", args[0]);
         exit(1);
     }
     let jid = &args[1];
-    let password = &args[2];
+    let password = args[2].clone();
     let target = &args[3];
 
-    // tokio_core context
-    let mut rt = Runtime::new().unwrap();
     // Client instance
-    let client = Client::new(jid, password).unwrap();
+    let mut client = Client::new(jid, password).unwrap();
 
-    // Make the two interfaces for sending and receiving independent
-    // of each other so we can move one into a closure.
-    let (mut sink, stream) = client.split();
-    // Wrap sink in Option so that we can take() it for the send(self)
-    // to consume and return it back when ready.
-    let mut send = move |packet| {
-        sink.start_send(packet).expect("start_send");
-    };
     // Main loop, processes events
     let mut wait_for_stream_end = false;
-    let done = stream.for_each(|event| {
-        if wait_for_stream_end {
-            /* Do Nothing. */
-        } else if event.is_online() {
-            println!("Online!");
+    let mut stream_ended = false;
+    while !stream_ended {
+        if let Some(event) = client.next().await {
+            if wait_for_stream_end {
+                /* Do Nothing. */
+            } else if event.is_online() {
+                println!("Online!");
 
-            let target_jid: Jid = target.clone().parse().unwrap();
-            let iq = make_disco_iq(target_jid);
-            println!("Sending disco#info request to {}", target.clone());
-            println!(">> {}", String::from(&iq));
-            send(Packet::Stanza(iq));
-        } else if let Some(stanza) = event.into_stanza() {
-            if stanza.is("iq", "jabber:client") {
-                let iq = Iq::try_from(stanza).unwrap();
-                if let IqType::Result(Some(payload)) = iq.payload {
-                    if payload.is("query", ns::DISCO_INFO) {
-                        if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
-                            for ext in disco_info.extensions {
-                                if let Ok(server_info) = ServerInfo::try_from(ext) {
-                                    print_server_info(server_info);
-                                    wait_for_stream_end = true;
-                                    send(Packet::StreamEnd);
+                let target_jid: Jid = target.clone().parse().unwrap();
+                let iq = make_disco_iq(target_jid);
+                println!("Sending disco#info request to {}", target.clone());
+                println!(">> {}", String::from(&iq));
+                client.send_stanza(iq).await.unwrap();
+            } else if let Some(stanza) = event.into_stanza() {
+                if stanza.is("iq", "jabber:client") {
+                    let iq = Iq::try_from(stanza).unwrap();
+                    if let IqType::Result(Some(payload)) = iq.payload {
+                        if payload.is("query", ns::DISCO_INFO) {
+                            if let Ok(disco_info) = DiscoInfoResult::try_from(payload) {
+                                for ext in disco_info.extensions {
+                                    if let Ok(server_info) = ServerInfo::try_from(ext) {
+                                        print_server_info(server_info);
+                                    }
                                 }
                             }
                         }
+                        wait_for_stream_end = true;
+                        client.send_end().await.unwrap();
                     }
                 }
             }
-        }
-
-        Box::new(future::ok(()))
-    });
-
-    // Start polling `done`
-    match rt.block_on(done) {
-        Ok(_) => (),
-        Err(e) => {
-            println!("Fatal: {}", e);
-            ()
+        } else {
+            stream_ended = true;
         }
     }
 }

tokio-xmpp/examples/download_avatars.rs πŸ”—

@@ -1,12 +1,12 @@
-use futures::{future, Future, Sink, Stream};
+use futures::stream::StreamExt;
 use std::convert::TryFrom;
 use std::env::args;
 use std::fs::{create_dir_all, File};
 use std::io::{self, Write};
 use std::process::exit;
 use std::str::FromStr;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{Client, Packet};
+use tokio;
+use tokio_xmpp::Client;
 use xmpp_parsers::{
     avatar::{Data as AvatarData, Metadata as AvatarMetadata},
     caps::{compute_disco, hash_caps, Caps},
@@ -22,162 +22,153 @@ use xmpp_parsers::{
         NodeName,
     },
     stanza_error::{DefinedCondition, ErrorType, StanzaError},
-    Jid,
+    Element, Jid,
 };
 
-fn main() {
+#[tokio::main]
+async fn main() {
     let args: Vec<String> = args().collect();
     if args.len() != 3 {
         println!("Usage: {} <jid> <password>", args[0]);
         exit(1);
     }
     let jid = &args[1];
-    let password = &args[2];
+    let password = args[2].clone();
 
-    // tokio_core context
-    let mut rt = Runtime::new().unwrap();
     // Client instance
-    let client = Client::new(jid, password).unwrap();
-
-    // Make the two interfaces for sending and receiving independent
-    // of each other so we can move one into a closure.
-    let (sink, stream) = client.split();
-
-    // Create outgoing pipe
-    let (mut tx, rx) = futures::unsync::mpsc::unbounded();
-    rt.spawn(
-        rx.forward(sink.sink_map_err(|_| panic!("Pipe")))
-            .map(|(rx, mut sink)| {
-                drop(rx);
-                let _ = sink.close();
-            })
-            .map_err(|e| {
-                panic!("Send error: {:?}", e);
-            }),
-    );
+    let mut client = Client::new(jid, password).unwrap();
 
     let disco_info = make_disco();
 
     // Main loop, processes events
     let mut wait_for_stream_end = false;
-    let done = stream.for_each(move |event| {
-        // Helper function to send an iq error.
-        let mut send_error = |to, id, type_, condition, text: &str| {
-            let error = StanzaError::new(type_, condition, "en", text);
-            let iq = Iq::from_error(id, error).with_to(to);
-            tx.start_send(Packet::Stanza(iq.into())).unwrap();
-        };
-
-        if wait_for_stream_end {
-            /* Do nothing */
-        } else if event.is_online() {
-            println!("Online!");
-
-            let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp");
-            let presence = make_presence(caps);
-            tx.start_send(Packet::Stanza(presence.into())).unwrap();
-        } else if let Some(stanza) = event.into_stanza() {
-            if stanza.is("iq", "jabber:client") {
-                let iq = Iq::try_from(stanza).unwrap();
-                if let IqType::Get(payload) = iq.payload {
-                    if payload.is("query", ns::DISCO_INFO) {
-                        let query = DiscoInfoQuery::try_from(payload);
-                        match query {
-                            Ok(query) => {
-                                let mut disco = disco_info.clone();
-                                disco.node = query.node;
-                                let iq =
-                                    Iq::from_result(iq.id, Some(disco)).with_to(iq.from.unwrap());
-                                tx.start_send(Packet::Stanza(iq.into())).unwrap();
+    let mut stream_ended = false;
+    while !stream_ended {
+        if let Some(event) = client.next().await {
+            if wait_for_stream_end {
+                /* Do nothing */
+            } else if event.is_online() {
+                println!("Online!");
+
+                let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp");
+                let presence = make_presence(caps);
+                client.send_stanza(presence.into()).await.unwrap();
+            } else if let Some(stanza) = event.into_stanza() {
+                if stanza.is("iq", "jabber:client") {
+                    let iq = Iq::try_from(stanza).unwrap();
+                    if let IqType::Get(payload) = iq.payload {
+                        if payload.is("query", ns::DISCO_INFO) {
+                            let query = DiscoInfoQuery::try_from(payload);
+                            match query {
+                                Ok(query) => {
+                                    let mut disco = disco_info.clone();
+                                    disco.node = query.node;
+                                    let iq = Iq::from_result(iq.id, Some(disco))
+                                        .with_to(iq.from.unwrap());
+                                    client.send_stanza(iq.into()).await.unwrap();
+                                }
+                                Err(err) => client
+                                    .send_stanza(make_error(
+                                        iq.from.unwrap(),
+                                        iq.id,
+                                        ErrorType::Modify,
+                                        DefinedCondition::BadRequest,
+                                        &format!("{}", err),
+                                    ))
+                                    .await
+                                    .unwrap(),
                             }
-                            Err(err) => {
-                                send_error(
+                        } else {
+                            // We MUST answer unhandled get iqs with a service-unavailable error.
+                            client
+                                .send_stanza(make_error(
                                     iq.from.unwrap(),
                                     iq.id,
-                                    ErrorType::Modify,
-                                    DefinedCondition::BadRequest,
-                                    &format!("{}", err),
-                                );
-                            }
+                                    ErrorType::Cancel,
+                                    DefinedCondition::ServiceUnavailable,
+                                    "No handler defined for this kind of iq.",
+                                ))
+                                .await
+                                .unwrap();
                         }
-                    } else {
-                        // We MUST answer unhandled get iqs with a service-unavailable error.
-                        send_error(
-                            iq.from.unwrap(),
-                            iq.id,
-                            ErrorType::Cancel,
-                            DefinedCondition::ServiceUnavailable,
-                            "No handler defined for this kind of iq.",
-                        );
-                    }
-                } else if let IqType::Result(Some(payload)) = iq.payload {
-                    if payload.is("pubsub", ns::PUBSUB) {
-                        let pubsub = PubSub::try_from(payload).unwrap();
-                        let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap());
-                        handle_iq_result(pubsub, &from);
+                    } else if let IqType::Result(Some(payload)) = iq.payload {
+                        if payload.is("pubsub", ns::PUBSUB) {
+                            let pubsub = PubSub::try_from(payload).unwrap();
+                            let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap());
+                            handle_iq_result(pubsub, &from);
+                        }
+                    } else if let IqType::Set(_) = iq.payload {
+                        // We MUST answer unhandled set iqs with a service-unavailable error.
+                        client
+                            .send_stanza(make_error(
+                                iq.from.unwrap(),
+                                iq.id,
+                                ErrorType::Cancel,
+                                DefinedCondition::ServiceUnavailable,
+                                "No handler defined for this kind of iq.",
+                            ))
+                            .await
+                            .unwrap();
                     }
-                } else if let IqType::Set(_) = iq.payload {
-                    // We MUST answer unhandled set iqs with a service-unavailable error.
-                    send_error(
-                        iq.from.unwrap(),
-                        iq.id,
-                        ErrorType::Cancel,
-                        DefinedCondition::ServiceUnavailable,
-                        "No handler defined for this kind of iq.",
-                    );
-                }
-            } else if stanza.is("message", "jabber:client") {
-                let message = Message::try_from(stanza).unwrap();
-                let from = message.from.clone().unwrap();
-                if let Some(body) = message.get_best_body(vec!["en"]) {
-                    if body.1 .0 == "die" {
-                        println!("Secret die command triggered by {}", from);
-                        wait_for_stream_end = true;
-                        tx.start_send(Packet::StreamEnd).unwrap();
+                } else if stanza.is("message", "jabber:client") {
+                    let message = Message::try_from(stanza).unwrap();
+                    let from = message.from.clone().unwrap();
+                    if let Some(body) = message.get_best_body(vec!["en"]) {
+                        if body.0 == "die" {
+                            println!("Secret die command triggered by {}", from);
+                            wait_for_stream_end = true;
+                            client.send_end().await.unwrap();
+                        }
                     }
-                }
-                for child in message.payloads {
-                    if child.is("event", ns::PUBSUB_EVENT) {
-                        let event = PubSubEvent::try_from(child).unwrap();
-                        if let PubSubEvent::PublishedItems { node, items } = event {
-                            if node.0 == ns::AVATAR_METADATA {
-                                for item in items.into_iter() {
-                                    let payload = item.payload.clone().unwrap();
-                                    if payload.is("metadata", ns::AVATAR_METADATA) {
-                                        // TODO: do something with these metadata.
-                                        let _metadata = AvatarMetadata::try_from(payload).unwrap();
-                                        println!(
-                                            "{} has published an avatar, downloading...",
-                                            from.clone()
-                                        );
-                                        let iq = download_avatar(from.clone());
-                                        tx.start_send(Packet::Stanza(iq.into())).unwrap();
+                    for child in message.payloads {
+                        if child.is("event", ns::PUBSUB_EVENT) {
+                            let event = PubSubEvent::try_from(child).unwrap();
+                            if let PubSubEvent::PublishedItems { node, items } = event {
+                                if node.0 == ns::AVATAR_METADATA {
+                                    for item in items.into_iter() {
+                                        let payload = item.payload.clone().unwrap();
+                                        if payload.is("metadata", ns::AVATAR_METADATA) {
+                                            // TODO: do something with these metadata.
+                                            let _metadata =
+                                                AvatarMetadata::try_from(payload).unwrap();
+                                            println!(
+                                                "{} has published an avatar, downloading...",
+                                                from.clone()
+                                            );
+                                            let iq = download_avatar(from.clone());
+                                            client.send_stanza(iq.into()).await.unwrap();
+                                        }
                                     }
                                 }
                             }
                         }
                     }
+                } else if stanza.is("presence", "jabber:client") {
+                    // Nothing to do here.
+                    ()
+                } else {
+                    panic!("Unknown stanza: {}", String::from(&stanza));
                 }
-            } else if stanza.is("presence", "jabber:client") {
-                // Nothing to do here.
-            } else {
-                panic!("Unknown stanza: {}", String::from(&stanza));
             }
-        }
-
-        future::ok(())
-    });
-
-    // Start polling `done`
-    match rt.block_on(done) {
-        Ok(_) => (),
-        Err(e) => {
-            println!("Fatal: {}", e);
-            ()
+        } else {
+            println!("stream_ended");
+            stream_ended = true;
         }
     }
 }
 
+fn make_error(
+    to: Jid,
+    id: String,
+    type_: ErrorType,
+    condition: DefinedCondition,
+    text: &str,
+) -> Element {
+    let error = StanzaError::new(type_, condition, "en", text);
+    let iq = Iq::from_error(id, error).with_to(to);
+    iq.into()
+}
+
 fn make_disco() -> DiscoInfoResult {
     let identities = vec![Identity::new("client", "bot", "en", "tokio-xmpp")];
     let features = vec![
@@ -235,6 +226,7 @@ fn handle_iq_result(pubsub: PubSub, from: &Jid) {
     }
 }
 
+// TODO: may use tokio?
 fn save_avatar(from: &Jid, id: String, data: &[u8]) -> io::Result<()> {
     let directory = format!("data/{}", from);
     let filename = format!("data/{}/{}", from, id);

tokio-xmpp/examples/echo_bot.rs πŸ”—

@@ -1,14 +1,15 @@
-use futures::{future, Future, Sink, Stream};
+use futures::stream::StreamExt;
 use std::convert::TryFrom;
 use std::env::args;
 use std::process::exit;
-use tokio::runtime::current_thread::Runtime;
-use tokio_xmpp::{Client, Packet};
+use tokio;
+use tokio_xmpp::Client;
 use xmpp_parsers::message::{Body, Message, MessageType};
 use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType};
 use xmpp_parsers::{Element, Jid};
 
-fn main() {
+#[tokio::main]
+async fn main() {
     let args: Vec<String> = args().collect();
     if args.len() != 3 {
         println!("Usage: {} <jid> <password>", args[0]);
@@ -17,72 +18,50 @@ fn main() {
     let jid = &args[1];
     let password = &args[2];
 
-    // tokio_core context
-    let mut rt = Runtime::new().unwrap();
     // Client instance
-    let client = Client::new(jid, password).unwrap();
-
-    // Make the two interfaces for sending and receiving independent
-    // of each other so we can move one into a closure.
-    let (sink, stream) = client.split();
-
-    // Create outgoing pipe
-    let (mut tx, rx) = futures::unsync::mpsc::unbounded();
-    rt.spawn(
-        rx.forward(sink.sink_map_err(|_| panic!("Pipe")))
-            .map(|(rx, mut sink)| {
-                drop(rx);
-                let _ = sink.close();
-            })
-            .map_err(|e| {
-                panic!("Send error: {:?}", e);
-            }),
-    );
+    let mut client = Client::new(jid, password.to_owned()).unwrap();
+    client.set_reconnect(true);
 
     // Main loop, processes events
     let mut wait_for_stream_end = false;
-    let done = stream.for_each(move |event| {
-        if wait_for_stream_end {
-            /* Do nothing */
-        } else if event.is_online() {
-            let jid = event
-                .get_jid()
-                .map(|jid| format!("{}", jid))
-                .unwrap_or("unknown".to_owned());
-            println!("Online at {}", jid);
+    let mut stream_ended = false;
+    while !stream_ended {
+        if let Some(event) = client.next().await {
+            println!("event: {:?}", event);
+            if wait_for_stream_end {
+                /* Do nothing */
+            } else if event.is_online() {
+                let jid = event
+                    .get_jid()
+                    .map(|jid| format!("{}", jid))
+                    .unwrap_or("unknown".to_owned());
+                println!("Online at {}", jid);
 
-            let presence = make_presence();
-            tx.start_send(Packet::Stanza(presence)).unwrap();
-        } else if let Some(message) = event
-            .into_stanza()
-            .and_then(|stanza| Message::try_from(stanza).ok())
-        {
-            match (message.from, message.bodies.get("")) {
-                (Some(ref from), Some(ref body)) if body.0 == "die" => {
-                    println!("Secret die command triggered by {}", from);
-                    wait_for_stream_end = true;
-                    tx.start_send(Packet::StreamEnd).unwrap();
-                }
-                (Some(ref from), Some(ref body)) => {
-                    if message.type_ != MessageType::Error {
-                        // This is a message we'll echo
-                        let reply = make_reply(from.clone(), &body.0);
-                        tx.start_send(Packet::Stanza(reply)).unwrap();
+                let presence = make_presence();
+                client.send_stanza(presence).await.unwrap();
+            } else if let Some(message) = event
+                .into_stanza()
+                .and_then(|stanza| Message::try_from(stanza).ok())
+            {
+                match (message.from, message.bodies.get("")) {
+                    (Some(ref from), Some(ref body)) if body.0 == "die" => {
+                        println!("Secret die command triggered by {}", from);
+                        wait_for_stream_end = true;
+                        client.send_end().await.unwrap();
+                    }
+                    (Some(ref from), Some(ref body)) => {
+                        if message.type_ != MessageType::Error {
+                            // This is a message we'll echo
+                            let reply = make_reply(from.clone(), &body.0);
+                            client.send_stanza(reply).await.unwrap();
+                        }
                     }
+                    _ => {}
                 }
-                _ => {}
             }
-        }
-
-        future::ok(())
-    });
-
-    // Start polling `done`
-    match rt.block_on(done) {
-        Ok(_) => (),
-        Err(e) => {
-            println!("Fatal: {}", e);
-            ()
+        } else {
+            println!("stream_ended");
+            stream_ended = true;
         }
     }
 }

tokio-xmpp/examples/echo_component.rs πŸ”—

@@ -1,15 +1,15 @@
-use futures::{future, Sink, Stream};
+use futures::stream::StreamExt;
 use std::convert::TryFrom;
 use std::env::args;
 use std::process::exit;
 use std::str::FromStr;
-use tokio::runtime::current_thread::Runtime;
 use tokio_xmpp::Component;
 use xmpp_parsers::message::{Body, Message, MessageType};
 use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType};
 use xmpp_parsers::{Element, Jid};
 
-fn main() {
+#[tokio::main]
+async fn main() {
     let args: Vec<String> = args().collect();
     if args.len() < 3 || args.len() > 5 {
         println!("Usage: {} <jid> <password> [server] [port]", args[0]);
@@ -24,57 +24,38 @@ fn main() {
         .unwrap_or("127.0.0.1".to_owned());
     let port: u16 = args.get(4).unwrap().parse().unwrap_or(5347u16);
 
-    // tokio_core context
-    let mut rt = Runtime::new().unwrap();
     // Component instance
     println!("{} {} {} {}", jid, password, server, port);
-    let component = Component::new(jid, password, server, port).unwrap();
+    let mut component = Component::new(jid, password, server, port).await.unwrap();
 
     // Make the two interfaces for sending and receiving independent
     // of each other so we can move one into a closure.
-    println!("Got it: {}", component.jid.clone());
-    let (mut sink, stream) = component.split();
-    // Wrap sink in Option so that we can take() it for the send(self)
-    // to consume and return it back when ready.
-    let mut send = move |stanza| {
-        sink.start_send(stanza).expect("start_send");
-    };
-    // Main loop, processes events
-    let done = stream.for_each(|event| {
-        if event.is_online() {
-            println!("Online!");
+    println!("Online: {}", component.jid);
+
+    // TODO: replace these hardcoded JIDs
+    let presence = make_presence(
+        Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(),
+        Jid::from_str("linkmauve@linkmauve.fr").unwrap(),
+    );
+    component.send_stanza(presence).await.unwrap();
 
-            // TODO: replace these hardcoded JIDs
-            let presence = make_presence(
-                Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(),
-                Jid::from_str("linkmauve@linkmauve.fr").unwrap(),
-            );
-            send(presence);
-        } else if let Some(message) = event
-            .into_stanza()
-            .and_then(|stanza| Message::try_from(stanza).ok())
-        {
-            // This is a message we'll echo
-            match (message.from, message.bodies.get("")) {
-                (Some(from), Some(body)) => {
-                    if message.type_ != MessageType::Error {
-                        let reply = make_reply(from, &body.0);
-                        send(reply);
+    // Main loop, processes events
+    loop {
+        if let Some(stanza) = component.next().await {
+            if let Some(message) = Message::try_from(stanza).ok() {
+                // This is a message we'll echo
+                match (message.from, message.bodies.get("")) {
+                    (Some(from), Some(body)) => {
+                        if message.type_ != MessageType::Error {
+                            let reply = make_reply(from, &body.0);
+                            component.send_stanza(reply).await.unwrap();
+                        }
                     }
+                    _ => (),
                 }
-                _ => (),
             }
-        }
-
-        Box::new(future::ok(()))
-    });
-
-    // Start polling `done`
-    match rt.block_on(done) {
-        Ok(_) => (),
-        Err(e) => {
-            println!("Fatal: {}", e);
-            ()
+        } else {
+            break;
         }
     }
 }

tokio-xmpp/src/client/auth.rs πŸ”—

@@ -1,7 +1,4 @@
-use futures::{
-    future::{err, ok, IntoFuture},
-    Future, Poll, Stream,
-};
+use futures::stream::StreamExt;
 use sasl::client::mechanisms::{Anonymous, Plain, Scram};
 use sasl::client::Mechanism;
 use sasl::common::scram::{Sha1, Sha256};
@@ -9,7 +6,7 @@ use sasl::common::Credentials;
 use std::collections::HashSet;
 use std::convert::TryFrom;
 use std::str::FromStr;
-use tokio_io::{AsyncRead, AsyncWrite};
+use tokio::io::{AsyncRead, AsyncWrite};
 use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
 
 use crate::xmpp_codec::Packet;
@@ -18,109 +15,70 @@ use crate::{AuthError, Error, ProtocolError};
 
 const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 
-pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
-    future: Box<dyn Future<Item = XMPPStream<S>, Error = Error>>,
-}
-
-impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
-    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
-        let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism>>> = vec![
-            Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
-            Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
-            Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
-            Box::new(|| Box::new(Anonymous::new())),
-        ];
+pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+    mut stream: XMPPStream<S>,
+    creds: Credentials,
+) -> Result<S, Error> {
+    let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
+        Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
+        Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
+        Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
+        Box::new(|| Box::new(Anonymous::new())),
+    ];
 
-        let remote_mechs: HashSet<String> = stream
-            .stream_features
-            .get_child("mechanisms", NS_XMPP_SASL)
-            .ok_or(AuthError::NoMechanism)?
-            .children()
-            .filter(|child| child.is("mechanism", NS_XMPP_SASL))
-            .map(|mech_el| mech_el.text())
-            .collect();
+    let remote_mechs: HashSet<String> = stream
+        .stream_features
+        .get_child("mechanisms", NS_XMPP_SASL)
+        .ok_or(AuthError::NoMechanism)?
+        .children()
+        .filter(|child| child.is("mechanism", NS_XMPP_SASL))
+        .map(|mech_el| mech_el.text())
+        .collect();
 
-        for local_mech in local_mechs {
-            let mut mechanism = local_mech();
-            if remote_mechs.contains(mechanism.name()) {
-                let initial = mechanism.initial().map_err(AuthError::Sasl)?;
-                let mechanism_name =
-                    XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
+    for local_mech in local_mechs {
+        let mut mechanism = local_mech();
+        if remote_mechs.contains(mechanism.name()) {
+            let initial = mechanism.initial().map_err(AuthError::Sasl)?;
+            let mechanism_name =
+                XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
 
-                let send_initial = Box::new(stream.send_stanza(Auth {
+            stream
+                .send_stanza(Auth {
                     mechanism: mechanism_name,
                     data: initial,
-                }))
-                .map_err(Error::Io);
-                let future = Box::new(
-                    send_initial
-                        .and_then(|stream| Self::handle_challenge(stream, mechanism))
-                        .and_then(|stream| stream.restart()),
-                );
-                return Ok(ClientAuth { future });
-            }
-        }
+                })
+                .await?;
 
-        Err(AuthError::NoMechanism)?
-    }
+            loop {
+                match stream.next().await {
+                    Some(Ok(Packet::Stanza(stanza))) => {
+                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
+                            let response = mechanism
+                                .response(&challenge.data)
+                                .map_err(|e| AuthError::Sasl(e))?;
 
-    fn handle_challenge(
-        stream: XMPPStream<S>,
-        mut mechanism: Box<dyn Mechanism>,
-    ) -> Box<dyn Future<Item = XMPPStream<S>, Error = Error>> {
-        Box::new(
-            stream
-                .into_future()
-                .map_err(|(e, _stream)| e.into())
-                .and_then(|(stanza, stream)| {
-                    match stanza {
-                        Some(Packet::Stanza(stanza)) => {
-                            if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
-                                let response = mechanism.response(&challenge.data);
-                                Box::new(
-                                    response
-                                        .map_err(|e| AuthError::Sasl(e).into())
-                                        .into_future()
-                                        .and_then(|response| {
-                                            // Send response and loop
-                                            stream
-                                                .send_stanza(Response { data: response })
-                                                .map_err(Error::Io)
-                                                .and_then(|stream| {
-                                                    Self::handle_challenge(stream, mechanism)
-                                                })
-                                        }),
-                                )
-                            } else if let Ok(_) = Success::try_from(stanza.clone()) {
-                                Box::new(ok(stream))
-                            } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
-                                Box::new(err(Error::Auth(AuthError::Fail(
-                                    failure.defined_condition,
-                                ))))
-                            } else if stanza.name() == "failure" {
-                                // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
-                                Box::new(err(Error::Auth(AuthError::Sasl("failure".to_string()))))
-                            } else {
-                                // ignore and loop
-                                Self::handle_challenge(stream, mechanism)
-                            }
-                        }
-                        Some(_) => {
+                            // Send response and loop
+                            stream.send_stanza(Response { data: response }).await?;
+                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
+                            return Ok(stream.into_inner());
+                        } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
+                            return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
+                        } else if stanza.name() == "failure" {
+                            // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
+                            return Err(Error::Auth(AuthError::Sasl("failure".to_string())));
+                        } else {
                             // ignore and loop
-                            Self::handle_challenge(stream, mechanism)
                         }
-                        None => Box::new(err(Error::Disconnected)),
                     }
-                }),
-        )
+                    Some(Ok(_)) => {
+                        // ignore and loop
+                    }
+                    Some(Err(e)) => return Err(e),
+                    None => return Err(Error::Disconnected),
+                }
+            }
+        }
     }
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
-    type Item = XMPPStream<S>;
-    type Error = Error;
 
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        self.future.poll()
-    }
+    Err(AuthError::NoMechanism.into())
 }

tokio-xmpp/src/client/bind.rs πŸ”—

@@ -1,7 +1,7 @@
-use futures::{sink, Async, Future, Poll, Stream};
+use futures::stream::StreamExt;
 use std::convert::TryFrom;
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
 use xmpp_parsers::bind::{BindQuery, BindResponse};
 use xmpp_parsers::iq::{Iq, IqType};
 use xmpp_parsers::Jid;
@@ -13,90 +13,43 @@ use crate::{Error, ProtocolError};
 const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
 const BIND_REQ_ID: &str = "resource-bind";
 
-pub enum ClientBind<S: AsyncWrite> {
-    Unsupported(XMPPStream<S>),
-    WaitSend(sink::Send<XMPPStream<S>>),
-    WaitRecv(XMPPStream<S>),
-    Invalid,
-}
-
-impl<S: AsyncWrite> ClientBind<S> {
-    /// Consumes and returns the stream to express that you cannot use
-    /// the stream for anything else until the resource binding
-    /// req/resp are done.
-    pub fn new(stream: XMPPStream<S>) -> Self {
-        match stream.stream_features.get_child("bind", NS_XMPP_BIND) {
-            None =>
+pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
+    mut stream: XMPPStream<S>,
+) -> Result<XMPPStream<S>, Error> {
+    match stream.stream_features.get_child("bind", NS_XMPP_BIND) {
+        None => {
             // No resource binding available,
             // return the (probably // usable) stream immediately
-            {
-                ClientBind::Unsupported(stream)
-            }
-            Some(_) => {
-                let resource;
-                if let Jid::Full(jid) = stream.jid.clone() {
-                    resource = Some(jid.resource);
-                } else {
-                    resource = None;
-                }
-                let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
-                let send = stream.send_stanza(iq);
-                ClientBind::WaitSend(send)
-            }
+            return Ok(stream);
         }
-    }
-}
+        Some(_) => {
+            let resource = if let Jid::Full(jid) = stream.jid.clone() {
+                Some(jid.resource)
+            } else {
+                None
+            };
+            let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
+            stream.send_stanza(iq).await?;
 
-impl<S: AsyncRead + AsyncWrite> Future for ClientBind<S> {
-    type Item = XMPPStream<S>;
-    type Error = Error;
-
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let state = replace(self, ClientBind::Invalid);
-
-        match state {
-            ClientBind::Unsupported(stream) => Ok(Async::Ready(stream)),
-            ClientBind::WaitSend(mut send) => match send.poll() {
-                Ok(Async::Ready(stream)) => {
-                    replace(self, ClientBind::WaitRecv(stream));
-                    self.poll()
-                }
-                Ok(Async::NotReady) => {
-                    replace(self, ClientBind::WaitSend(send));
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e)?,
-            },
-            ClientBind::WaitRecv(mut stream) => match stream.poll() {
-                Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => match Iq::try_from(stanza) {
-                    Ok(iq) => {
-                        if iq.id == BIND_REQ_ID {
-                            match iq.payload {
-                                IqType::Result(payload) => {
-                                    payload
-                                        .and_then(|payload| BindResponse::try_from(payload).ok())
-                                        .map(|bind| stream.jid = bind.into());
-                                    Ok(Async::Ready(stream))
-                                }
-                                _ => Err(ProtocolError::InvalidBindResponse)?,
+            loop {
+                match stream.next().await {
+                    Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
+                        Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
+                            IqType::Result(payload) => {
+                                payload
+                                    .and_then(|payload| BindResponse::try_from(payload).ok())
+                                    .map(|bind| stream.jid = bind.into());
+                                return Ok(stream);
                             }
-                        } else {
-                            Ok(Async::NotReady)
-                        }
-                    }
-                    _ => Ok(Async::NotReady),
-                },
-                Ok(Async::Ready(_)) => {
-                    replace(self, ClientBind::WaitRecv(stream));
-                    self.poll()
-                }
-                Ok(Async::NotReady) => {
-                    replace(self, ClientBind::WaitRecv(stream));
-                    Ok(Async::NotReady)
+                            _ => return Err(ProtocolError::InvalidBindResponse.into()),
+                        },
+                        _ => {}
+                    },
+                    Some(Ok(_)) => {}
+                    Some(Err(e)) => return Err(e),
+                    None => return Err(Error::Disconnected),
                 }
-                Err(e) => Err(e)?,
-            },
-            ClientBind::Invalid => unreachable!(),
+            }
         }
     }
 }

tokio-xmpp/src/client/mod.rs πŸ”—

@@ -1,28 +1,33 @@
-use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
+use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
 use idna;
 use sasl::common::{ChannelBinding, Credentials};
 use std::mem::replace;
+use std::pin::Pin;
 use std::str::FromStr;
+use std::task::Context;
+use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::TcpStream;
-use tokio_io::{AsyncRead, AsyncWrite};
+use tokio::task::JoinHandle;
+use tokio::task::LocalSet;
 use tokio_tls::TlsStream;
-use xmpp_parsers::{Jid, JidParseError};
+use xmpp_parsers::{Element, Jid, JidParseError};
 
 use super::event::Event;
-use super::happy_eyeballs::Connecter;
-use super::starttls::{StartTlsClient, NS_XMPP_TLS};
+use super::happy_eyeballs::connect;
+use super::starttls::{starttls, NS_XMPP_TLS};
 use super::xmpp_codec::Packet;
 use super::xmpp_stream;
 use super::{Error, ProtocolError};
 
 mod auth;
-use self::auth::ClientAuth;
 mod bind;
-use self::bind::ClientBind;
 
 /// XMPP client connection and state
 pub struct Client {
     state: ClientState,
+    jid: Jid,
+    password: String,
+    reconnect: bool,
 }
 
 type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
@@ -31,7 +36,7 @@ const NS_JABBER_CLIENT: &str = "jabber:client";
 enum ClientState {
     Invalid,
     Disconnected,
-    Connecting(Box<dyn Future<Item = XMPPStream, Error = Error>>),
+    Connecting(JoinHandle<Result<XMPPStream, Error>>, LocalSet),
     Connected(XMPPStream),
 }
 
@@ -40,87 +45,87 @@ impl Client {
     ///
     /// Start polling the returned instance so that it will connect
     /// and yield events.
-    pub fn new(jid: &str, password: &str) -> Result<Self, JidParseError> {
+    pub fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, JidParseError> {
         let jid = Jid::from_str(jid)?;
-        let client = Self::new_with_jid(jid, password);
+        let client = Self::new_with_jid(jid, password.into());
         Ok(client)
     }
 
     /// Start a new client given that the JID is already parsed.
-    pub fn new_with_jid(jid: Jid, password: &str) -> Self {
-        let password = password.to_owned();
-        let connect = Self::make_connect(jid, password.clone());
+    pub fn new_with_jid(jid: Jid, password: String) -> Self {
+        let local = LocalSet::new();
+        let connect = local.spawn_local(Self::connect(jid.clone(), password.clone()));
         let client = Client {
-            state: ClientState::Connecting(Box::new(connect)),
+            jid,
+            password,
+            state: ClientState::Connecting(connect, local),
+            reconnect: false,
         };
         client
     }
 
-    fn make_connect(jid: Jid, password: String) -> impl Future<Item = XMPPStream, Error = Error> {
+    /// Set whether to reconnect (`true`) or end the stream (`false`)
+    /// when a connection to the server has ended.
+    pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
+        self.reconnect = reconnect;
+        self
+    }
+
+    async fn connect(jid: Jid, password: String) -> Result<XMPPStream, Error> {
         let username = jid.clone().node().unwrap();
-        let jid1 = jid.clone();
-        let jid2 = jid.clone();
         let password = password;
-        done(idna::domain_to_ascii(&jid.domain()))
-            .map_err(|_| Error::Idna)
-            .and_then(|domain| {
-                done(Connecter::from_lookup(
-                    &domain,
-                    Some("_xmpp-client._tcp"),
-                    5222,
-                ))
-            })
-            .flatten()
-            .and_then(move |tcp_stream| {
-                xmpp_stream::XMPPStream::start(tcp_stream, jid1, NS_JABBER_CLIENT.to_owned())
-            })
-            .and_then(|xmpp_stream| {
-                if Self::can_starttls(&xmpp_stream) {
-                    Ok(Self::starttls(xmpp_stream))
-                } else {
-                    Err(Error::Protocol(ProtocolError::NoTls))
-                }
-            })
-            .flatten()
-            .and_then(|tls_stream| XMPPStream::start(tls_stream, jid2, NS_JABBER_CLIENT.to_owned()))
-            .and_then(
-                move |xmpp_stream| done(Self::auth(xmpp_stream, username, password)), // TODO: flatten?
-            )
-            .and_then(|auth| auth)
-            .and_then(|xmpp_stream| Self::bind(xmpp_stream))
-            .and_then(|xmpp_stream| {
-                // println!("Bound to {}", xmpp_stream.jid);
-                Ok(xmpp_stream)
-            })
-    }
-
-    fn can_starttls<S>(stream: &xmpp_stream::XMPPStream<S>) -> bool {
-        stream
+        let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
+
+        let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?;
+
+        let xmpp_stream =
+            xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?;
+        let xmpp_stream = if Self::can_starttls(&xmpp_stream) {
+            Self::starttls(xmpp_stream).await?
+        } else {
+            return Err(Error::Protocol(ProtocolError::NoTls));
+        };
+
+        let xmpp_stream = Self::auth(xmpp_stream, username, password).await?;
+        let xmpp_stream = Self::bind(xmpp_stream).await?;
+        Ok(xmpp_stream)
+    }
+
+    fn can_starttls<S: AsyncRead + AsyncWrite + Unpin>(
+        xmpp_stream: &xmpp_stream::XMPPStream<S>,
+    ) -> bool {
+        xmpp_stream
             .stream_features
             .get_child("starttls", NS_XMPP_TLS)
             .is_some()
     }
 
-    fn starttls<S: AsyncRead + AsyncWrite>(
-        stream: xmpp_stream::XMPPStream<S>,
-    ) -> StartTlsClient<S> {
-        StartTlsClient::from_stream(stream)
+    async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+        xmpp_stream: xmpp_stream::XMPPStream<S>,
+    ) -> Result<xmpp_stream::XMPPStream<TlsStream<S>>, Error> {
+        let jid = xmpp_stream.jid.clone();
+        let tls_stream = starttls(xmpp_stream).await?;
+        xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await
     }
 
-    fn auth<S: AsyncRead + AsyncWrite + 'static>(
-        stream: xmpp_stream::XMPPStream<S>,
+    async fn auth<S: AsyncRead + AsyncWrite + Unpin + 'static>(
+        xmpp_stream: xmpp_stream::XMPPStream<S>,
         username: String,
         password: String,
-    ) -> Result<ClientAuth<S>, Error> {
+    ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
+        let jid = xmpp_stream.jid.clone();
         let creds = Credentials::default()
             .with_username(username)
             .with_password(password)
             .with_channel_binding(ChannelBinding::None);
-        ClientAuth::new(stream, creds)
+        let stream = auth::auth(xmpp_stream, creds).await?;
+        xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await
     }
 
-    fn bind<S: AsyncWrite>(stream: xmpp_stream::XMPPStream<S>) -> ClientBind<S> {
-        ClientBind::new(stream)
+    async fn bind<S: Unpin + AsyncRead + AsyncWrite>(
+        stream: xmpp_stream::XMPPStream<S>,
+    ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
+        bind::bind(stream).await
     }
 
     /// Get the client's bound JID (the one reported by the XMPP
@@ -131,102 +136,150 @@ impl Client {
             _ => None,
         }
     }
+
+    /// Send stanza
+    pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
+        self.send(Packet::Stanza(stanza)).await
+    }
+
+    /// End connection
+    pub async fn send_end(&mut self) -> Result<(), Error> {
+        self.send(Packet::StreamEnd).await
+    }
 }
 
 impl Stream for Client {
     type Item = Event;
-    type Error = Error;
 
-    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
         let state = replace(&mut self.state, ClientState::Invalid);
 
         match state {
-            ClientState::Invalid => Err(Error::InvalidState),
-            ClientState::Disconnected => Ok(Async::Ready(None)),
-            ClientState::Connecting(mut connect) => match connect.poll() {
-                Ok(Async::Ready(stream)) => {
-                    let jid = stream.jid.clone();
-                    self.state = ClientState::Connected(stream);
-                    Ok(Async::Ready(Some(Event::Online(jid))))
-                }
-                Ok(Async::NotReady) => {
-                    self.state = ClientState::Connecting(connect);
-                    Ok(Async::NotReady)
+            ClientState::Invalid => panic!("Invalid client state"),
+            ClientState::Disconnected if self.reconnect => {
+                // TODO: add timeout
+                let mut local = LocalSet::new();
+                let connect =
+                    local.spawn_local(Self::connect(self.jid.clone(), self.password.clone()));
+                let _ = Pin::new(&mut local).poll(cx);
+                self.state = ClientState::Connecting(connect, local);
+                self.poll_next(cx)
+            }
+            ClientState::Disconnected => Poll::Ready(None),
+            ClientState::Connecting(mut connect, mut local) => {
+                match Pin::new(&mut connect).poll(cx) {
+                    Poll::Ready(Ok(Ok(stream))) => {
+                        let bound_jid = stream.jid.clone();
+                        self.state = ClientState::Connected(stream);
+                        Poll::Ready(Some(Event::Online(bound_jid)))
+                    }
+                    Poll::Ready(Ok(Err(e))) => {
+                        self.state = ClientState::Disconnected;
+                        return Poll::Ready(Some(Event::Disconnected(e.into())));
+                    }
+                    Poll::Ready(Err(e)) => {
+                        self.state = ClientState::Disconnected;
+                        panic!("connect task: {}", e);
+                    }
+                    Poll::Pending => {
+                        let _ = Pin::new(&mut local).poll(cx);
+
+                        self.state = ClientState::Connecting(connect, local);
+                        Poll::Pending
+                    }
                 }
-                Err(e) => Err(e),
-            },
+            }
             ClientState::Connected(mut stream) => {
                 // Poll sink
-                match stream.poll_complete() {
-                    Ok(Async::NotReady) => (),
-                    Ok(Async::Ready(())) => (),
-                    Err(e) => return Err(e)?,
+                match Pin::new(&mut stream).poll_ready(cx) {
+                    Poll::Pending => (),
+                    Poll::Ready(Ok(())) => (),
+                    Poll::Ready(Err(e)) => {
+                        self.state = ClientState::Disconnected;
+                        return Poll::Ready(Some(Event::Disconnected(e.into())));
+                    }
                 };
 
                 // Poll stream
-                match stream.poll() {
-                    Ok(Async::Ready(None)) => {
+                match Pin::new(&mut stream).poll_next(cx) {
+                    Poll::Ready(None) => {
                         // EOF
                         self.state = ClientState::Disconnected;
-                        Ok(Async::Ready(Some(Event::Disconnected)))
+                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
                     }
-                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
+                    Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
                         // Receive stanza
                         self.state = ClientState::Connected(stream);
-                        Ok(Async::Ready(Some(Event::Stanza(stanza))))
+                        Poll::Ready(Some(Event::Stanza(stanza)))
                     }
-                    Ok(Async::Ready(Some(Packet::Text(_)))) => {
+                    Poll::Ready(Some(Ok(Packet::Text(_)))) => {
                         // Ignore text between stanzas
                         self.state = ClientState::Connected(stream);
-                        Ok(Async::NotReady)
+                        Poll::Pending
                     }
-                    Ok(Async::Ready(Some(Packet::StreamStart(_)))) => {
+                    Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
                         // <stream:stream>
-                        Err(ProtocolError::InvalidStreamStart.into())
+                        self.state = ClientState::Disconnected;
+                        Poll::Ready(Some(Event::Disconnected(
+                            ProtocolError::InvalidStreamStart.into(),
+                        )))
                     }
-                    Ok(Async::Ready(Some(Packet::StreamEnd))) => {
+                    Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
                         // End of stream: </stream:stream>
-                        Ok(Async::Ready(None))
+                        self.state = ClientState::Disconnected;
+                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
                     }
-                    Ok(Async::NotReady) => {
+                    Poll::Pending => {
                         // Try again later
                         self.state = ClientState::Connected(stream);
-                        Ok(Async::NotReady)
+                        Poll::Pending
+                    }
+                    Poll::Ready(Some(Err(e))) => {
+                        self.state = ClientState::Disconnected;
+                        Poll::Ready(Some(Event::Disconnected(e.into())))
                     }
-                    Err(e) => Err(e)?,
                 }
             }
         }
     }
 }
 
-impl Sink for Client {
-    type SinkItem = Packet;
-    type SinkError = Error;
+impl Sink<Packet> for Client {
+    type Error = Error;
 
-    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
+    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
         match self.state {
-            ClientState::Connected(ref mut stream) => Ok(stream.start_send(item)?),
-            _ => Ok(AsyncSink::NotReady(item)),
+            ClientState::Connected(ref mut stream) => {
+                Pin::new(stream).start_send(item).map_err(|e| e.into())
+            }
+            _ => Err(Error::InvalidState),
         }
     }
 
-    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
         match self.state {
-            ClientState::Connected(ref mut stream) => stream.poll_complete().map_err(|e| e.into()),
-            _ => Ok(Async::Ready(())),
+            ClientState::Connected(ref mut stream) => {
+                Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
+            }
+            _ => Poll::Pending,
         }
     }
 
-    /// This closes the inner TCP stream.
-    ///
-    /// To synchronize your shutdown with the server side, you should
-    /// first send `Packet::StreamEnd` and wait for the end of the
-    /// incoming stream before closing the connection.
-    fn close(&mut self) -> Poll<(), Self::SinkError> {
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
         match self.state {
-            ClientState::Connected(ref mut stream) => stream.close().map_err(|e| e.into()),
-            _ => Ok(Async::Ready(())),
+            ClientState::Connected(ref mut stream) => {
+                Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
+            }
+            _ => Poll::Pending,
+        }
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        match self.state {
+            ClientState::Connected(ref mut stream) => {
+                Pin::new(stream).poll_close(cx).map_err(|e| e.into())
+            }
+            _ => Poll::Pending,
         }
     }
 }

tokio-xmpp/src/component/auth.rs πŸ”—

@@ -1,6 +1,6 @@
-use futures::{sink, Async, Future, Poll, Stream};
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::stream::StreamExt;
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
 use xmpp_parsers::component::Handshake;
 
 use crate::xmpp_codec::Packet;
@@ -9,81 +9,27 @@ use crate::{AuthError, Error};
 
 const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept";
 
-pub struct ComponentAuth<S: AsyncWrite> {
-    state: ComponentAuthState<S>,
-}
-
-enum ComponentAuthState<S: AsyncWrite> {
-    WaitSend(sink::Send<XMPPStream<S>>),
-    WaitRecv(XMPPStream<S>),
-    Invalid,
-}
-
-impl<S: AsyncWrite> ComponentAuth<S> {
-    // TODO: doesn't have to be a Result<> actually
-    pub fn new(stream: XMPPStream<S>, password: String) -> Result<Self, Error> {
-        // FIXME: huge hack, shouldn’t be an element!
-        let sid = stream.stream_features.name().to_owned();
-        let mut this = ComponentAuth {
-            state: ComponentAuthState::Invalid,
-        };
-        this.send(
-            stream,
-            Handshake::from_password_and_stream_id(&password, &sid),
-        );
-        Ok(this)
-    }
-
-    fn send(&mut self, stream: XMPPStream<S>, handshake: Handshake) {
-        let nonza = handshake;
-        let send = stream.send_stanza(nonza);
-
-        self.state = ComponentAuthState::WaitSend(send);
-    }
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for ComponentAuth<S> {
-    type Item = XMPPStream<S>;
-    type Error = Error;
-
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let state = replace(&mut self.state, ComponentAuthState::Invalid);
-
-        match state {
-            ComponentAuthState::WaitSend(mut send) => match send.poll() {
-                Ok(Async::Ready(stream)) => {
-                    self.state = ComponentAuthState::WaitRecv(stream);
-                    self.poll()
-                }
-                Ok(Async::NotReady) => {
-                    self.state = ComponentAuthState::WaitSend(send);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e)?,
-            },
-            ComponentAuthState::WaitRecv(mut stream) => match stream.poll() {
-                Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
-                    if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) =>
-                {
-                    self.state = ComponentAuthState::Invalid;
-                    Ok(Async::Ready(stream))
-                }
-                Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
-                    if stanza.is("error", "http://etherx.jabber.org/streams") =>
-                {
-                    Err(AuthError::ComponentFail.into())
-                }
-                Ok(Async::Ready(_event)) => {
-                    // println!("ComponentAuth ignore {:?}", _event);
-                    Ok(Async::NotReady)
-                }
-                Ok(_) => {
-                    self.state = ComponentAuthState::WaitRecv(stream);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e)?,
-            },
-            ComponentAuthState::Invalid => unreachable!(),
+pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
+    stream: &mut XMPPStream<S>,
+    password: String,
+) -> Result<(), Error> {
+    let nonza = Handshake::from_password_and_stream_id(&password, &stream.id);
+    stream.send_stanza(nonza).await?;
+
+    loop {
+        match stream.next().await {
+            Some(Ok(Packet::Stanza(ref stanza)))
+                if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) =>
+            {
+                return Ok(());
+            }
+            Some(Ok(Packet::Stanza(ref stanza)))
+                if stanza.is("error", "http://etherx.jabber.org/streams") =>
+            {
+                return Err(AuthError::ComponentFail.into());
+            }
+            Some(_) => {}
+            None => return Err(Error::Disconnected),
         }
     }
 }

tokio-xmpp/src/component/mod.rs πŸ”—

@@ -1,163 +1,115 @@
 //! Components in XMPP are services/gateways that are logged into an
 //! XMPP server under a JID consisting of just a domain name. They are
 //! allowed to use any user and resource identifiers in their stanzas.
-use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream};
-use std::mem::replace;
+use futures::{sink::SinkExt, task::Poll, Sink, Stream};
+use std::pin::Pin;
 use std::str::FromStr;
+use std::task::Context;
 use tokio::net::TcpStream;
-use tokio_io::{AsyncRead, AsyncWrite};
-use xmpp_parsers::{Element, Jid, JidParseError};
+use xmpp_parsers::{Element, Jid};
 
-use super::event::Event;
-use super::happy_eyeballs::Connecter;
+use super::happy_eyeballs::connect;
 use super::xmpp_codec::Packet;
 use super::xmpp_stream;
 use super::Error;
 
 mod auth;
-use self::auth::ComponentAuth;
 
 /// Component connection to an XMPP server
+///
+/// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element`
+/// (stanzas). Connection handling however is up to the user.
 pub struct Component {
     /// The component's Jabber-Id
     pub jid: Jid,
-    state: ComponentState,
+    stream: XMPPStream,
 }
 
 type XMPPStream = xmpp_stream::XMPPStream<TcpStream>;
 const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept";
 
-enum ComponentState {
-    Invalid,
-    Disconnected,
-    Connecting(Box<dyn Future<Item = XMPPStream, Error = Error>>),
-    Connected(XMPPStream),
-}
-
 impl Component {
     /// Start a new XMPP component
-    ///
-    /// Start polling the returned instance so that it will connect
-    /// and yield events.
-    pub fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, JidParseError> {
+    pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, Error> {
         let jid = Jid::from_str(jid)?;
         let password = password.to_owned();
-        let connect = Self::make_connect(jid.clone(), password, server, port);
-        Ok(Component {
-            jid,
-            state: ComponentState::Connecting(Box::new(connect)),
-        })
+        let stream = Self::connect(jid.clone(), password, server, port).await?;
+        Ok(Component { jid, stream })
     }
 
-    fn make_connect(
+    async fn connect(
         jid: Jid,
         password: String,
         server: &str,
         port: u16,
-    ) -> impl Future<Item = XMPPStream, Error = Error> {
-        let jid1 = jid.clone();
+    ) -> Result<XMPPStream, Error> {
         let password = password;
-        done(Connecter::from_lookup(server, None, port))
-            .flatten()
-            .and_then(move |tcp_stream| {
-                xmpp_stream::XMPPStream::start(
-                    tcp_stream,
-                    jid1,
-                    NS_JABBER_COMPONENT_ACCEPT.to_owned(),
-                )
-            })
-            .and_then(move |xmpp_stream| Self::auth(xmpp_stream, password).expect("auth"))
+        let tcp_stream = connect(server, None, port).await?;
+        let mut xmpp_stream =
+            xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_COMPONENT_ACCEPT.to_owned())
+                .await?;
+        auth::auth(&mut xmpp_stream, password).await?;
+        Ok(xmpp_stream)
     }
 
-    fn auth<S: AsyncRead + AsyncWrite>(
-        stream: xmpp_stream::XMPPStream<S>,
-        password: String,
-    ) -> Result<ComponentAuth<S>, Error> {
-        ComponentAuth::new(stream, password)
+    /// Send stanza
+    pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
+        self.send(stanza).await
+    }
+
+    /// End connection
+    pub async fn send_end(&mut self) -> Result<(), Error> {
+        self.close().await
     }
 }
 
 impl Stream for Component {
-    type Item = Event;
-    type Error = Error;
-
-    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
-        let state = replace(&mut self.state, ComponentState::Invalid);
+    type Item = Element;
 
-        match state {
-            ComponentState::Invalid => Err(Error::InvalidState),
-            ComponentState::Disconnected => Ok(Async::Ready(None)),
-            ComponentState::Connecting(mut connect) => match connect.poll() {
-                Ok(Async::Ready(stream)) => {
-                    self.state = ComponentState::Connected(stream);
-                    Ok(Async::Ready(Some(Event::Online(self.jid.clone()))))
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        loop {
+            match Pin::new(&mut self.stream).poll_next(cx) {
+                Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => return Poll::Ready(Some(stanza)),
+                Poll::Ready(Some(Ok(Packet::Text(_)))) => {
+                    // retry
                 }
-                Ok(Async::NotReady) => {
-                    self.state = ComponentState::Connecting(connect);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e),
-            },
-            ComponentState::Connected(mut stream) => {
-                // Poll sink
-                match stream.poll_complete() {
-                    Ok(Async::NotReady) => (),
-                    Ok(Async::Ready(())) => (),
-                    Err(e) => return Err(e)?,
-                };
-
-                // Poll stream
-                match stream.poll() {
-                    Ok(Async::NotReady) => {
-                        self.state = ComponentState::Connected(stream);
-                        Ok(Async::NotReady)
-                    }
-                    Ok(Async::Ready(None)) => {
-                        // EOF
-                        self.state = ComponentState::Disconnected;
-                        Ok(Async::Ready(Some(Event::Disconnected)))
-                    }
-                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
-                        self.state = ComponentState::Connected(stream);
-                        Ok(Async::Ready(Some(Event::Stanza(stanza))))
-                    }
-                    Ok(Async::Ready(_)) => {
-                        self.state = ComponentState::Connected(stream);
-                        Ok(Async::NotReady)
-                    }
-                    Err(e) => Err(e)?,
+                Poll::Ready(Some(Ok(_))) =>
+                // unexpected
+                {
+                    return Poll::Ready(None)
                 }
+                Poll::Ready(Some(Err(_))) => return Poll::Ready(None),
+                Poll::Ready(None) => return Poll::Ready(None),
+                Poll::Pending => return Poll::Pending,
             }
         }
     }
 }
 
-impl Sink for Component {
-    type SinkItem = Element;
-    type SinkError = Error;
+impl Sink<Element> for Component {
+    type Error = Error;
 
-    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
-        match self.state {
-            ComponentState::Connected(ref mut stream) => match stream
-                .start_send(Packet::Stanza(item))
-            {
-                Ok(AsyncSink::NotReady(Packet::Stanza(stanza))) => Ok(AsyncSink::NotReady(stanza)),
-                Ok(AsyncSink::NotReady(_)) => {
-                    panic!("Component.start_send with stanza but got something else back")
-                }
-                Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
-                Err(e) => Err(e)?,
-            },
-            _ => Ok(AsyncSink::NotReady(item)),
-        }
+    fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> {
+        Pin::new(&mut self.stream)
+            .start_send(Packet::Stanza(item))
+            .map_err(|e| e.into())
     }
 
-    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
-        match &mut self.state {
-            &mut ComponentState::Connected(ref mut stream) => {
-                stream.poll_complete().map_err(|e| e.into())
-            }
-            _ => Ok(Async::Ready(())),
-        }
+    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.stream)
+            .poll_ready(cx)
+            .map_err(|e| e.into())
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.stream)
+            .poll_flush(cx)
+            .map_err(|e| e.into())
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.stream)
+            .poll_close(cx)
+            .map_err(|e| e.into())
     }
 }

tokio-xmpp/src/error.rs πŸ”—

@@ -8,7 +8,7 @@ use trust_dns_proto::error::ProtoError;
 use trust_dns_resolver::error::ResolveError;
 
 use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition;
-use xmpp_parsers::Error as ParsersError;
+use xmpp_parsers::{Error as ParsersError, JidParseError};
 
 /// Top-level error type
 #[derive(Debug)]
@@ -20,6 +20,8 @@ pub enum Error {
     /// DNS label conversion error, no details available from module
     /// `idna`
     Idna,
+    /// Error parsing Jabber-Id
+    JidParse(JidParseError),
     /// Protocol-level error
     Protocol(ProtocolError),
     /// Authentication error
@@ -38,6 +40,7 @@ impl fmt::Display for Error {
             Error::Io(e) => write!(fmt, "IO error: {}", e),
             Error::Connection(e) => write!(fmt, "connection error: {}", e),
             Error::Idna => write!(fmt, "IDNA error"),
+            Error::JidParse(e) => write!(fmt, "jid parse error: {}", e),
             Error::Protocol(e) => write!(fmt, "protocol error: {}", e),
             Error::Auth(e) => write!(fmt, "authentication error: {}", e),
             Error::Tls(e) => write!(fmt, "TLS error: {}", e),
@@ -59,6 +62,12 @@ impl From<ConnecterError> for Error {
     }
 }
 
+impl From<JidParseError> for Error {
+    fn from(e: JidParseError) -> Self {
+        Error::JidParse(e)
+    }
+}
+
 impl From<ProtocolError> for Error {
     fn from(e: ProtocolError) -> Self {
         Error::Protocol(e)

tokio-xmpp/src/event.rs πŸ”—

@@ -1,3 +1,4 @@
+use super::Error;
 use xmpp_parsers::{Element, Jid};
 
 /// High-level event on the Stream implemented by Client and Component
@@ -6,7 +7,7 @@ pub enum Event {
     /// Stream is connected and initialized
     Online(Jid),
     /// Stream end
-    Disconnected,
+    Disconnected(Error),
     /// Received stanza/nonza
     Stanza(Element),
 }

tokio-xmpp/src/happy_eyeballs.rs πŸ”—

@@ -1,195 +1,63 @@
 use crate::{ConnecterError, Error};
-use futures::{Async, Future, Poll};
-use std::cell::RefCell;
-use std::collections::BTreeMap;
-use std::collections::VecDeque;
-use std::io::Error as IoError;
-use std::mem;
 use std::net::SocketAddr;
-use tokio::net::tcp::ConnectFuture;
 use tokio::net::TcpStream;
-use trust_dns_resolver::config::LookupIpStrategy;
-use trust_dns_resolver::lookup::SrvLookupFuture;
-use trust_dns_resolver::lookup_ip::LookupIpFuture;
-use trust_dns_resolver::{AsyncResolver, Background, BackgroundLookup, IntoName, Name};
-
-enum State {
-    ResolveSrv(AsyncResolver, BackgroundLookup<SrvLookupFuture>),
-    ResolveTarget(AsyncResolver, Background<LookupIpFuture>, u16),
-    Connecting(Option<AsyncResolver>, Vec<RefCell<ConnectFuture>>),
-    Invalid,
+use trust_dns_resolver::{IntoName, TokioAsyncResolver};
+
+async fn connect_to_host(
+    resolver: &TokioAsyncResolver,
+    host: &str,
+    port: u16,
+) -> Result<TcpStream, Error> {
+    let ips = resolver
+        .lookup_ip(host)
+        .await
+        .map_err(ConnecterError::Resolve)?;
+    for ip in ips.iter() {
+        match TcpStream::connect(&SocketAddr::new(ip, port)).await {
+            Ok(stream) => return Ok(stream),
+            Err(_) => {}
+        }
+    }
+    Err(Error::Disconnected)
 }
 
-pub struct Connecter {
+pub async fn connect(
+    domain: &str,
+    srv: Option<&str>,
     fallback_port: u16,
-    srv_domain: Option<Name>,
-    domain: Name,
-    state: State,
-    targets: VecDeque<(Name, u16)>,
-    error: Option<Error>,
-}
-
-fn resolver() -> Result<AsyncResolver, IoError> {
-    let (config, mut opts) = trust_dns_resolver::system_conf::read_system_conf()?;
-    opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
-    let (resolver, resolver_background) = AsyncResolver::new(config, opts);
-    tokio::runtime::current_thread::spawn(resolver_background);
-    Ok(resolver)
-}
-
-impl Connecter {
-    pub fn from_lookup(
-        domain: &str,
-        srv: Option<&str>,
-        fallback_port: u16,
-    ) -> Result<Connecter, Error> {
-        if let Ok(ip) = domain.parse() {
-            // use specified IP address, not domain name, skip the whole dns part
-            let connect = RefCell::new(TcpStream::connect(&SocketAddr::new(ip, fallback_port)));
-            return Ok(Connecter {
-                fallback_port,
-                srv_domain: None,
-                domain: "nohost".into_name().map_err(ConnecterError::Dns)?,
-                state: State::Connecting(None, vec![connect]),
-                targets: VecDeque::new(),
-                error: None,
-            });
-        }
-
-        let srv_domain = match srv {
-            Some(srv) => Some(
-                format!("{}.{}.", srv, domain)
-                    .into_name()
-                    .map_err(ConnecterError::Dns)?,
-            ),
-            None => None,
-        };
-
-        let mut self_ = Connecter {
-            fallback_port,
-            srv_domain,
-            domain: domain.into_name().map_err(ConnecterError::Dns)?,
-            state: State::Invalid,
-            targets: VecDeque::new(),
-            error: None,
-        };
-
-        let resolver = resolver()?;
-        // Initialize state
-        match &self_.srv_domain {
-            &Some(ref srv_domain) => {
-                let srv_lookup = resolver.lookup_srv(srv_domain.clone());
-                self_.state = State::ResolveSrv(resolver, srv_lookup);
-            }
-            None => {
-                self_.targets = [(self_.domain.clone(), self_.fallback_port)]
-                    .iter()
-                    .cloned()
-                    .collect();
-                self_.state = State::Connecting(Some(resolver), vec![]);
-            }
-        }
-
-        Ok(self_)
+) -> Result<TcpStream, Error> {
+    if let Ok(ip) = domain.parse() {
+        return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
     }
-}
 
-impl Future for Connecter {
-    type Item = TcpStream;
-    type Error = Error;
+    let resolver = TokioAsyncResolver::tokio_from_system_conf()
+        .await
+        .map_err(ConnecterError::Resolve)?;
 
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let state = mem::replace(&mut self.state, State::Invalid);
-        match state {
-            State::ResolveSrv(resolver, mut srv_lookup) => {
-                match srv_lookup.poll() {
-                    Ok(Async::NotReady) => {
-                        self.state = State::ResolveSrv(resolver, srv_lookup);
-                        Ok(Async::NotReady)
-                    }
-                    Ok(Async::Ready(srv_result)) => {
-                        let srv_map: BTreeMap<_, _> = srv_result
-                            .iter()
-                            .map(|srv| (srv.priority(), (srv.target().clone(), srv.port())))
-                            .collect();
-                        let targets = srv_map.into_iter().map(|(_, tp)| tp).collect();
-                        self.targets = targets;
-                        self.state = State::Connecting(Some(resolver), vec![]);
-                        self.poll()
-                    }
-                    Err(_) => {
-                        // ignore, fallback
-                        self.targets = [(self.domain.clone(), self.fallback_port)]
-                            .iter()
-                            .cloned()
-                            .collect();
-                        self.state = State::Connecting(Some(resolver), vec![]);
-                        self.poll()
-                    }
-                }
-            }
-            State::Connecting(resolver, mut connects) => {
-                if resolver.is_some() && connects.len() == 0 && self.targets.len() > 0 {
-                    let resolver = resolver.unwrap();
-                    let (host, port) = self.targets.pop_front().unwrap();
-                    let ip_lookup = resolver.lookup_ip(host);
-                    self.state = State::ResolveTarget(resolver, ip_lookup, port);
-                    self.poll()
-                } else if connects.len() > 0 {
-                    let mut success = None;
-                    connects.retain(|connect| match connect.borrow_mut().poll() {
-                        Ok(Async::NotReady) => true,
-                        Ok(Async::Ready(connection)) => {
-                            success = Some(connection);
-                            false
-                        }
-                        Err(e) => {
-                            if self.error.is_none() {
-                                self.error = Some(e.into());
-                            }
-                            false
-                        }
-                    });
-                    match success {
-                        Some(connection) => Ok(Async::Ready(connection)),
-                        None => {
-                            self.state = State::Connecting(resolver, connects);
-                            Ok(Async::NotReady)
-                        }
-                    }
-                } else {
-                    // All targets tried
-                    match self.error.take() {
-                        None => Err(ConnecterError::AllFailed.into()),
-                        Some(e) => Err(e),
-                    }
-                }
-            }
-            State::ResolveTarget(resolver, mut ip_lookup, port) => {
-                match ip_lookup.poll() {
-                    Ok(Async::NotReady) => {
-                        self.state = State::ResolveTarget(resolver, ip_lookup, port);
-                        Ok(Async::NotReady)
-                    }
-                    Ok(Async::Ready(ip_result)) => {
-                        let connects = ip_result
-                            .iter()
-                            .map(|ip| RefCell::new(TcpStream::connect(&SocketAddr::new(ip, port))))
-                            .collect();
-                        self.state = State::Connecting(Some(resolver), connects);
-                        self.poll()
-                    }
-                    Err(e) => {
-                        if self.error.is_none() {
-                            self.error = Some(ConnecterError::Resolve(e).into());
-                        }
-                        // ignore, next…
-                        self.state = State::Connecting(Some(resolver), vec![]);
-                        self.poll()
-                    }
+    let srv_records = match srv {
+        Some(srv) => {
+            let srv_domain = format!("{}.{}.", srv, domain)
+                .into_name()
+                .map_err(ConnecterError::Dns)?;
+            resolver.srv_lookup(srv_domain).await.ok()
+        }
+        None => None,
+    };
+
+    match srv_records {
+        Some(lookup) => {
+            // TODO: sort lookup records by priority/weight
+            for srv in lookup.iter() {
+                match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await {
+                    Ok(stream) => return Ok(stream),
+                    Err(_) => {}
                 }
             }
-            _ => panic!(""),
+            Err(Error::Disconnected)
+        }
+        None => {
+            // SRV lookup error, retry with hostname
+            connect_to_host(&resolver, domain, fallback_port).await
         }
     }
 }

tokio-xmpp/src/lib.rs πŸ”—

@@ -6,10 +6,9 @@ mod starttls;
 mod stream_start;
 pub mod xmpp_codec;
 pub use crate::xmpp_codec::Packet;
-pub mod xmpp_stream;
-pub use crate::starttls::StartTlsClient;
 mod event;
 mod happy_eyeballs;
+pub mod xmpp_stream;
 pub use crate::event::Event;
 mod client;
 pub use crate::client::Client;

tokio-xmpp/src/starttls.rs πŸ”—

@@ -1,114 +1,39 @@
-use futures::sink;
-use futures::stream::Stream;
-use futures::{Async, Future, Poll, Sink};
+use futures::{sink::SinkExt, stream::StreamExt};
 use native_tls::TlsConnector as NativeTlsConnector;
-use std::mem::replace;
-use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_tls::{Connect, TlsConnector, TlsStream};
-use xmpp_parsers::{Element, Jid};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_tls::{TlsConnector, TlsStream};
+use xmpp_parsers::Element;
 
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::XMPPStream;
-use crate::Error;
+use crate::{Error, ProtocolError};
 
 /// XMPP TLS XML namespace
 pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 
-/// XMPP stream that switches to TLS if available in received features
-pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
-    state: StartTlsClientState<S>,
-    jid: Jid,
-}
-
-enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
-    Invalid,
-    SendStartTls(sink::Send<XMPPStream<S>>),
-    AwaitProceed(XMPPStream<S>),
-    StartingTls(Connect<S>),
-}
-
-impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
-    /// Waits for <stream:features>
-    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
-        let jid = xmpp_stream.jid.clone();
-
-        let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
-        let packet = Packet::Stanza(nonza);
-        let send = xmpp_stream.send(packet);
-
-        StartTlsClient {
-            state: StartTlsClientState::SendStartTls(send),
-            jid,
+pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+    mut xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
+    let packet = Packet::Stanza(nonza);
+    xmpp_stream.send(packet).await?;
+
+    loop {
+        match xmpp_stream.next().await {
+            Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
+            Some(Ok(Packet::Text(_))) => {}
+            Some(Err(e)) => return Err(e.into()),
+            _ => {
+                return Err(ProtocolError::NoTls.into());
+            }
         }
     }
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
-    type Item = TlsStream<S>;
-    type Error = Error;
-
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
-        let mut retry = false;
 
-        let (new_state, result) = match old_state {
-            StartTlsClientState::SendStartTls(mut send) => match send.poll() {
-                Ok(Async::Ready(xmpp_stream)) => {
-                    let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
-                    retry = true;
-                    (new_state, Ok(Async::NotReady))
-                }
-                Ok(Async::NotReady) => {
-                    (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
-                }
-                Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
-            },
-            StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
-                Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
-                    if stanza.name() == "proceed" =>
-                {
-                    let stream = xmpp_stream.stream.into_inner();
-                    let connect =
-                        TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
-                            .connect(&self.jid.clone().domain(), stream);
-                    let new_state = StartTlsClientState::StartingTls(connect);
-                    retry = true;
-                    (new_state, Ok(Async::NotReady))
-                }
-                Ok(Async::Ready(_value)) => {
-                    // println!("StartTlsClient ignore {:?}", _value);
-                    (
-                        StartTlsClientState::AwaitProceed(xmpp_stream),
-                        Ok(Async::NotReady),
-                    )
-                }
-                Ok(_) => (
-                    StartTlsClientState::AwaitProceed(xmpp_stream),
-                    Ok(Async::NotReady),
-                ),
-                Err(e) => (
-                    StartTlsClientState::AwaitProceed(xmpp_stream),
-                    Err(Error::Protocol(e.into())),
-                ),
-            },
-            StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
-                Ok(Async::Ready(tls_stream)) => {
-                    (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
-                }
-                Ok(Async::NotReady) => (
-                    StartTlsClientState::StartingTls(connect),
-                    Ok(Async::NotReady),
-                ),
-                Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
-            },
-            StartTlsClientState::Invalid => unreachable!(),
-        };
+    let domain = xmpp_stream.jid.clone().domain();
+    let stream = xmpp_stream.into_inner();
+    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
+        .connect(&domain, stream)
+        .await?;
 
-        self.state = new_state;
-        if retry {
-            self.poll()
-        } else {
-            result
-        }
-    }
+    Ok(tls_stream)
 }

tokio-xmpp/src/stream_start.rs πŸ”—

@@ -1,7 +1,7 @@
-use futures::{sink, Async, Future, Poll, Sink, Stream};
-use std::mem::replace;
-use tokio_codec::Framed;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::{sink::SinkExt, stream::StreamExt};
+use std::marker::Unpin;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::codec::Framed;
 use xmpp_parsers::{Element, Jid};
 
 use crate::xmpp_codec::{Packet, XMPPCodec};
@@ -10,116 +10,66 @@ use crate::{Error, ProtocolError};
 
 const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 
-pub struct StreamStart<S: AsyncWrite> {
-    state: StreamStartState<S>,
+pub async fn start<S: AsyncRead + AsyncWrite + Unpin>(
+    mut stream: Framed<S, XMPPCodec>,
     jid: Jid,
     ns: String,
-}
-
-enum StreamStartState<S: AsyncWrite> {
-    SendStart(sink::Send<Framed<S, XMPPCodec>>),
-    RecvStart(Framed<S, XMPPCodec>),
-    RecvFeatures(Framed<S, XMPPCodec>, String),
-    Invalid,
-}
-
-impl<S: AsyncWrite> StreamStart<S> {
-    pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
-        let attrs = [
-            ("to".to_owned(), jid.clone().domain()),
-            ("version".to_owned(), "1.0".to_owned()),
-            ("xmlns".to_owned(), ns.clone()),
-            ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
-        ]
-        .iter()
-        .cloned()
-        .collect();
-        let send = stream.send(Packet::StreamStart(attrs));
+) -> Result<XMPPStream<S>, Error> {
+    let attrs = [
+        ("to".to_owned(), jid.clone().domain()),
+        ("version".to_owned(), "1.0".to_owned()),
+        ("xmlns".to_owned(), ns.clone()),
+        ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
+    ]
+    .iter()
+    .cloned()
+    .collect();
+    stream.send(Packet::StreamStart(attrs)).await?;
 
-        StreamStart {
-            state: StreamStartState::SendStart(send),
-            jid,
-            ns,
+    let stream_attrs;
+    loop {
+        match stream.next().await {
+            Some(Ok(Packet::StreamStart(attrs))) => {
+                stream_attrs = attrs;
+                break;
+            }
+            Some(Ok(_)) => {}
+            Some(Err(e)) => return Err(e.into()),
+            None => return Err(Error::Disconnected),
         }
     }
-}
-
-impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
-    type Item = XMPPStream<S>;
-    type Error = Error;
 
-    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let old_state = replace(&mut self.state, StreamStartState::Invalid);
-        let mut retry = false;
-
-        let (new_state, result) = match old_state {
-            StreamStartState::SendStart(mut send) => match send.poll() {
-                Ok(Async::Ready(stream)) => {
-                    retry = true;
-                    (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
-                }
-                Ok(Async::NotReady) => (StreamStartState::SendStart(send), Ok(Async::NotReady)),
-                Err(e) => (StreamStartState::Invalid, Err(e.into())),
-            },
-            StreamStartState::RecvStart(mut stream) => match stream.poll() {
-                Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
-                    let stream_ns = stream_attrs
-                        .get("xmlns")
-                        .ok_or(ProtocolError::NoStreamNamespace)?
-                        .clone();
-                    if self.ns == "jabber:client" {
-                        retry = true;
-                        // TODO: skip RecvFeatures for version < 1.0
-                        (
-                            StreamStartState::RecvFeatures(stream, stream_ns),
-                            Ok(Async::NotReady),
-                        )
-                    } else {
-                        let id = stream_attrs
-                            .get("id")
-                            .ok_or(ProtocolError::NoStreamId)?
-                            .clone();
-                        // FIXME: huge hack, shouldn’t be an element!
-                        let stream = XMPPStream::new(
-                            self.jid.clone(),
-                            stream,
-                            self.ns.clone(),
-                            Element::builder(id).build(),
-                        );
-                        (StreamStartState::Invalid, Ok(Async::Ready(stream)))
-                    }
+    let stream_ns = stream_attrs
+        .get("xmlns")
+        .ok_or(ProtocolError::NoStreamNamespace)?
+        .clone();
+    let stream_id = stream_attrs
+        .get("id")
+        .ok_or(ProtocolError::NoStreamId)?
+        .clone();
+    let stream = if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() {
+        let stream_features;
+        loop {
+            match stream.next().await {
+                Some(Ok(Packet::Stanza(stanza))) if stanza.is("features", NS_XMPP_STREAM) => {
+                    stream_features = stanza;
+                    break;
                 }
-                Ok(Async::Ready(_)) => return Err(ProtocolError::InvalidToken.into()),
-                Ok(Async::NotReady) => (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
-                Err(e) => return Err(ProtocolError::from(e).into()),
-            },
-            StreamStartState::RecvFeatures(mut stream, stream_ns) => match stream.poll() {
-                Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
-                    if stanza.is("features", NS_XMPP_STREAM) {
-                        let stream =
-                            XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
-                        (StreamStartState::Invalid, Ok(Async::Ready(stream)))
-                    } else {
-                        (
-                            StreamStartState::RecvFeatures(stream, stream_ns),
-                            Ok(Async::NotReady),
-                        )
-                    }
-                }
-                Ok(Async::Ready(_)) | Ok(Async::NotReady) => (
-                    StreamStartState::RecvFeatures(stream, stream_ns),
-                    Ok(Async::NotReady),
-                ),
-                Err(e) => return Err(ProtocolError::from(e).into()),
-            },
-            StreamStartState::Invalid => unreachable!(),
-        };
-
-        self.state = new_state;
-        if retry {
-            self.poll()
-        } else {
-            result
+                Some(Ok(_)) => {}
+                Some(Err(e)) => return Err(e.into()),
+                None => return Err(Error::Disconnected),
+            }
         }
-    }
+        XMPPStream::new(jid, stream, ns, stream_id, stream_features)
+    } else {
+        // FIXME: huge hack, shouldn’t be an element!
+        XMPPStream::new(
+            jid,
+            stream,
+            ns,
+            stream_id.clone(),
+            Element::builder(stream_id).build(),
+        )
+    };
+    Ok(stream)
 }

tokio-xmpp/src/xmpp_codec.rs πŸ”—

@@ -5,16 +5,16 @@ use bytes::{BufMut, BytesMut};
 use log::debug;
 use std;
 use std::borrow::Cow;
-use std::cell::RefCell;
 use std::collections::vec_deque::VecDeque;
 use std::collections::HashMap;
 use std::default::Default;
 use std::fmt::Write;
 use std::io;
 use std::iter::FromIterator;
-use std::rc::Rc;
 use std::str::from_utf8;
-use tokio_codec::{Decoder, Encoder};
+use std::sync::Arc;
+use std::sync::Mutex;
+use tokio_util::codec::{Decoder, Encoder};
 use xml5ever::buffer_queue::BufferQueue;
 use xml5ever::interface::Attribute;
 use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
@@ -38,14 +38,14 @@ type QueueItem = Result<Packet, ParserError>;
 /// Parser state
 struct ParserSink {
     // Ready stanzas, shared with XMPPCodec
-    queue: Rc<RefCell<VecDeque<QueueItem>>>,
+    queue: Arc<Mutex<VecDeque<QueueItem>>>,
     // Parsing stack
     stack: Vec<Element>,
     ns_stack: Vec<HashMap<Option<String>, String>>,
 }
 
 impl ParserSink {
-    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
+    pub fn new(queue: Arc<Mutex<VecDeque<QueueItem>>>) -> Self {
         ParserSink {
             queue,
             stack: vec![],
@@ -54,11 +54,11 @@ impl ParserSink {
     }
 
     fn push_queue(&self, pkt: Packet) {
-        self.queue.borrow_mut().push_back(Ok(pkt));
+        self.queue.lock().unwrap().push_back(Ok(pkt));
     }
 
     fn push_queue_error(&self, e: ParserError) {
-        self.queue.borrow_mut().push_back(Err(e));
+        self.queue.lock().unwrap().push_back(Err(e));
     }
 
     /// Lookup XML namespace declaration for given prefix (or no prefix)
@@ -169,7 +169,6 @@ impl TokenSink for ParserSink {
             },
             Token::EOFToken => self.push_queue(Packet::StreamEnd),
             Token::ParseError(s) => {
-                // println!("ParseError: {:?}", s);
                 self.push_queue_error(ParserError::Parse(ParseError(s)));
             }
             _ => (),
@@ -190,13 +189,13 @@ pub struct XMPPCodec {
     // TODO: optimize using  tendrils?
     buf: Vec<u8>,
     /// Shared with ParserSink
-    queue: Rc<RefCell<VecDeque<QueueItem>>>,
+    queue: Arc<Mutex<VecDeque<QueueItem>>>,
 }
 
 impl XMPPCodec {
     /// Constructor
     pub fn new() -> Self {
-        let queue = Rc::new(RefCell::new(VecDeque::new()));
+        let queue = Arc::new(Mutex::new(VecDeque::new()));
         let sink = ParserSink::new(queue.clone());
         // TODO: configure parser?
         let parser = XmlTokenizer::new(sink, Default::default());
@@ -222,10 +221,10 @@ impl Decoder for XMPPCodec {
     fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
         let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
             let mut prefix = std::mem::replace(&mut self.buf, vec![]);
-            prefix.extend_from_slice(buf.take().as_ref());
+            prefix.extend_from_slice(&buf.split_to(buf.len()));
             Box::new(prefix)
         } else {
-            Box::new(buf.take())
+            Box::new(buf.split_to(buf.len()))
         };
         let buf1 = buf1.as_ref().as_ref();
         match from_utf8(buf1) {
@@ -258,7 +257,7 @@ impl Decoder for XMPPCodec {
             }
         }
 
-        match self.queue.borrow_mut().pop_front() {
+        match self.queue.lock().unwrap().pop_front() {
             None => Ok(None),
             Some(result) => result.map(|pkt| Some(pkt)),
         }
@@ -372,7 +371,7 @@ mod tests {
     fn test_stream_start() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
@@ -384,14 +383,14 @@ mod tests {
     fn test_stream_end() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
             _ => false,
         });
         b.clear();
-        b.put(r"</stream:stream>");
+        b.put_slice(b"</stream:stream>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamEnd)) => true,
@@ -403,7 +402,7 @@ mod tests {
     fn test_truncated_stanza() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
@@ -411,7 +410,7 @@ mod tests {
         });
 
         b.clear();
-        b.put(r"<test>ß</test");
+        b.put_slice("<test>ß</test".as_bytes());
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(None) => true,
@@ -419,7 +418,7 @@ mod tests {
         });
 
         b.clear();
-        b.put(r">");
+        b.put_slice(b">");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
@@ -431,7 +430,7 @@ mod tests {
     fn test_truncated_utf8() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
@@ -460,7 +459,7 @@ mod tests {
     fn test_atrribute_prefix() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
@@ -468,7 +467,7 @@ mod tests {
         });
 
         b.clear();
-        b.put(r"<status xml:lang='en'>Test status</status>");
+        b.put_slice(b"<status xml:lang='en'>Test status</status>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::Stanza(ref el)))
@@ -483,10 +482,10 @@ mod tests {
     /// By default, encode() only get's a BytesMut that has 8kb space reserved.
     #[test]
     fn test_large_stanza() {
-        use futures::{Future, Sink};
+        use futures::{executor::block_on, sink::SinkExt};
         use std::io::Cursor;
-        use tokio_codec::FramedWrite;
-        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
+        use tokio_util::codec::FramedWrite;
+        let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
         let mut text = "".to_owned();
         for _ in 0..2usize.pow(15) {
             text = text + "A";
@@ -494,7 +493,7 @@ mod tests {
         let stanza = Element::builder("message")
             .append(Element::builder("body").append(text.as_ref()).build())
             .build();
-        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
+        block_on(framed.send(Packet::Stanza(stanza))).expect("send");
         assert_eq!(
             framed.get_ref().get_ref(),
             &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
@@ -505,7 +504,7 @@ mod tests {
     fn test_cut_out_stanza() {
         let mut c = XMPPCodec::new();
         let mut b = BytesMut::with_capacity(1024);
-        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::StreamStart(_))) => true,
@@ -513,8 +512,8 @@ mod tests {
         });
 
         b.clear();
-        b.put(r"<message ");
-        b.put(r"type='chat'><body>Foo</body></message>");
+        b.put_slice(b"<message ");
+        b.put_slice(b"type='chat'><body>Foo</body></message>");
         let r = c.decode(&mut b);
         assert!(match r {
             Ok(Some(Packet::Stanza(_))) => true,

tokio-xmpp/src/xmpp_stream.rs πŸ”—

@@ -1,23 +1,28 @@
 //! `XMPPStream` is the common container for all XMPP network connections
 
 use futures::sink::Send;
-use futures::{Poll, Sink, StartSend, Stream};
-use tokio_codec::Framed;
-use tokio_io::{AsyncRead, AsyncWrite};
+use futures::{sink::SinkExt, task::Poll, Sink, Stream};
+use std::ops::DerefMut;
+use std::pin::Pin;
+use std::sync::Mutex;
+use std::task::Context;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::codec::Framed;
 use xmpp_parsers::{Element, Jid};
 
-use crate::stream_start::StreamStart;
+use crate::stream_start;
 use crate::xmpp_codec::{Packet, XMPPCodec};
+use crate::Error;
 
 /// <stream:stream> namespace
 pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 
 /// Wraps a `stream`
-pub struct XMPPStream<S> {
+pub struct XMPPStream<S: AsyncRead + AsyncWrite + Unpin> {
     /// The local Jabber-Id
     pub jid: Jid,
     /// Codec instance
-    pub stream: Framed<S, XMPPCodec>,
+    pub stream: Mutex<Framed<S, XMPPCodec>>,
     /// `<stream:features/>` for XMPP version 1.0
     pub stream_features: Element,
     /// Root namespace
@@ -25,68 +30,94 @@ pub struct XMPPStream<S> {
     /// This is different for either c2s, s2s, or component
     /// connections.
     pub ns: String,
+    /// Stream `id` attribute
+    pub id: String,
 }
 
-impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
+// // TODO: fix this hack
+// unsafe impl<S: AsyncRead + AsyncWrite + Unpin> core::marker::Send for XMPPStream<S> {}
+// unsafe impl<S: AsyncRead + AsyncWrite + Unpin> Sync for XMPPStream<S> {}
+
+impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
     /// Constructor
     pub fn new(
         jid: Jid,
         stream: Framed<S, XMPPCodec>,
         ns: String,
+        id: String,
         stream_features: Element,
     ) -> Self {
         XMPPStream {
             jid,
-            stream,
+            stream: Mutex::new(stream),
             stream_features,
             ns,
+            id,
         }
     }
 
     /// Send a `<stream:stream>` start tag
-    pub fn start(stream: S, jid: Jid, ns: String) -> StreamStart<S> {
+    pub async fn start<'a>(stream: S, jid: Jid, ns: String) -> Result<Self, Error> {
         let xmpp_stream = Framed::new(stream, XMPPCodec::new());
-        StreamStart::from_stream(xmpp_stream, jid, ns)
+        stream_start::start(xmpp_stream, jid, ns).await
     }
 
     /// Unwraps the inner stream
+    // TODO: use this everywhere
     pub fn into_inner(self) -> S {
-        self.stream.into_inner()
+        self.stream.into_inner().unwrap().into_inner()
     }
 
     /// Re-run `start()`
-    pub fn restart(self) -> StreamStart<S> {
-        Self::start(self.stream.into_inner(), self.jid, self.ns)
+    pub async fn restart<'a>(self) -> Result<Self, Error> {
+        let stream = self.stream.into_inner().unwrap().into_inner();
+        Self::start(stream, self.jid, self.ns).await
     }
 }
 
-impl<S: AsyncWrite> XMPPStream<S> {
+impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
     /// Convenience method
-    pub fn send_stanza<E: Into<Element>>(self, e: E) -> Send<Self> {
+    pub fn send_stanza<E: Into<Element>>(&mut self, e: E) -> Send<Self, Packet> {
         self.send(Packet::Stanza(e.into()))
     }
 }
 
 /// Proxy to self.stream
-impl<S: AsyncWrite> Sink for XMPPStream<S> {
-    type SinkItem = <Framed<S, XMPPCodec> as Sink>::SinkItem;
-    type SinkError = <Framed<S, XMPPCodec> as Sink>::SinkError;
+impl<S: AsyncRead + AsyncWrite + Unpin> Sink<Packet> for XMPPStream<S> {
+    type Error = crate::Error;
+
+    fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        // Pin::new(&mut self.stream).poll_ready(ctx)
+        //     .map_err(|e| e.into())
+        Poll::Ready(Ok(()))
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
+        Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+            .start_send(item)
+            .map_err(|e| e.into())
+    }
 
-    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
-        self.stream.start_send(item)
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+            .poll_flush(cx)
+            .map_err(|e| e.into())
     }
 
-    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
-        self.stream.poll_complete()
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
+        Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+            .poll_close(cx)
+            .map_err(|e| e.into())
     }
 }
 
 /// Proxy to self.stream
-impl<S: AsyncRead> Stream for XMPPStream<S> {
-    type Item = <Framed<S, XMPPCodec> as Stream>::Item;
-    type Error = <Framed<S, XMPPCodec> as Stream>::Error;
+impl<S: AsyncRead + AsyncWrite + Unpin> Stream for XMPPStream<S> {
+    type Item = Result<Packet, crate::Error>;
 
-    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
-        self.stream.poll()
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        Pin::new(&mut self.stream.lock().unwrap().deref_mut())
+            .poll_next(cx)
+            .map(|result| result.map(|result| result.map_err(|e| e.into())))
     }
 }