From 23cb34e026cd82fb9fbb6bcbbf30f008d2a9bbbf Mon Sep 17 00:00:00 2001 From: Astro Date: Thu, 5 Mar 2020 01:25:24 +0100 Subject: [PATCH] tokio-xmpp: rewrite for futures-0.3 --- tokio-xmpp/Cargo.toml | 17 +- tokio-xmpp/examples/contact_addr.rs | 80 +++---- tokio-xmpp/examples/download_avatars.rs | 246 ++++++++++----------- tokio-xmpp/examples/echo_bot.rs | 103 ++++----- tokio-xmpp/examples/echo_component.rs | 71 +++--- tokio-xmpp/src/client/auth.rs | 154 +++++-------- tokio-xmpp/src/client/bind.rs | 115 +++------- tokio-xmpp/src/client/mod.rs | 277 ++++++++++++++---------- tokio-xmpp/src/component/auth.rs | 102 ++------- tokio-xmpp/src/component/mod.rs | 176 ++++++--------- tokio-xmpp/src/error.rs | 11 +- tokio-xmpp/src/event.rs | 3 +- tokio-xmpp/src/happy_eyeballs.rs | 230 +++++--------------- tokio-xmpp/src/lib.rs | 3 +- tokio-xmpp/src/starttls.rs | 127 +++-------- tokio-xmpp/src/stream_start.rs | 166 +++++--------- tokio-xmpp/src/xmpp_codec.rs | 57 +++-- tokio-xmpp/src/xmpp_stream.rs | 85 +++++--- 18 files changed, 801 insertions(+), 1222 deletions(-) diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 798f9d6a1dd2f710c750b54d03718d6c6abc505e..113614e7a345f6436ff980e01a2b109e0c291cff 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-xmpp" -version = "1.0.1" +version = "2.0.0" authors = ["Astro ", "Emmanuel Gil Peyrot ", "pep ", "O01eg "] description = "Asynchronous XMPP for Rust with tokio" license = "MPL-2.0" @@ -12,17 +12,16 @@ keywords = ["xmpp", "tokio"] edition = "2018" [dependencies] -bytes = "0.4" -futures = "0.1" +bytes = "0.5" +futures = "0.3" idna = "0.2" log = "0.4" native-tls = "0.2" sasl = "0.4" -tokio = "0.1" -tokio-codec = "0.1" -trust-dns-resolver = "0.12" -trust-dns-proto = "0.8" -tokio-io = "0.1" -tokio-tls = "0.2" +tokio = { version = "0.2", features = ["net", "stream", "rt-util", "rt-threaded", "macros"] } +tokio-util = { version = "0.2", features = ["codec"] } +tokio-tls = "0.3" +trust-dns-resolver = "0.19" +trust-dns-proto = "0.19" xml5ever = "0.16" xmpp-parsers = "0.17" diff --git a/tokio-xmpp/examples/contact_addr.rs b/tokio-xmpp/examples/contact_addr.rs index e3148437b2708d2ac8ce0cd983f2ec1729fb21de..c81f663ad1d102b77d41c5d7ff2681912c1c86a7 100644 --- a/tokio-xmpp/examples/contact_addr.rs +++ b/tokio-xmpp/examples/contact_addr.rs @@ -1,9 +1,8 @@ -use futures::{future, Sink, Stream}; +use futures::stream::StreamExt; use std::convert::TryFrom; use std::env::args; use std::process::exit; -use tokio::runtime::current_thread::Runtime; -use tokio_xmpp::{xmpp_codec::Packet, Client}; +use tokio_xmpp::Client; use xmpp_parsers::{ disco::{DiscoInfoQuery, DiscoInfoResult}, iq::{Iq, IqType}, @@ -12,70 +11,55 @@ use xmpp_parsers::{ Element, Jid, }; -fn main() { +#[tokio::main] +async fn main() { let args: Vec = args().collect(); if args.len() != 4 { println!("Usage: {} ", args[0]); exit(1); } let jid = &args[1]; - let password = &args[2]; + let password = args[2].clone(); let target = &args[3]; - // tokio_core context - let mut rt = Runtime::new().unwrap(); // Client instance - let client = Client::new(jid, password).unwrap(); + let mut client = Client::new(jid, password).unwrap(); - // Make the two interfaces for sending and receiving independent - // of each other so we can move one into a closure. - let (mut sink, stream) = client.split(); - // Wrap sink in Option so that we can take() it for the send(self) - // to consume and return it back when ready. - let mut send = move |packet| { - sink.start_send(packet).expect("start_send"); - }; // Main loop, processes events let mut wait_for_stream_end = false; - let done = stream.for_each(|event| { - if wait_for_stream_end { - /* Do Nothing. */ - } else if event.is_online() { - println!("Online!"); + let mut stream_ended = false; + while !stream_ended { + if let Some(event) = client.next().await { + if wait_for_stream_end { + /* Do Nothing. */ + } else if event.is_online() { + println!("Online!"); - let target_jid: Jid = target.clone().parse().unwrap(); - let iq = make_disco_iq(target_jid); - println!("Sending disco#info request to {}", target.clone()); - println!(">> {}", String::from(&iq)); - send(Packet::Stanza(iq)); - } else if let Some(stanza) = event.into_stanza() { - if stanza.is("iq", "jabber:client") { - let iq = Iq::try_from(stanza).unwrap(); - if let IqType::Result(Some(payload)) = iq.payload { - if payload.is("query", ns::DISCO_INFO) { - if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { - for ext in disco_info.extensions { - if let Ok(server_info) = ServerInfo::try_from(ext) { - print_server_info(server_info); - wait_for_stream_end = true; - send(Packet::StreamEnd); + let target_jid: Jid = target.clone().parse().unwrap(); + let iq = make_disco_iq(target_jid); + println!("Sending disco#info request to {}", target.clone()); + println!(">> {}", String::from(&iq)); + client.send_stanza(iq).await.unwrap(); + } else if let Some(stanza) = event.into_stanza() { + if stanza.is("iq", "jabber:client") { + let iq = Iq::try_from(stanza).unwrap(); + if let IqType::Result(Some(payload)) = iq.payload { + if payload.is("query", ns::DISCO_INFO) { + if let Ok(disco_info) = DiscoInfoResult::try_from(payload) { + for ext in disco_info.extensions { + if let Ok(server_info) = ServerInfo::try_from(ext) { + print_server_info(server_info); + } } } } + wait_for_stream_end = true; + client.send_end().await.unwrap(); } } } - } - - Box::new(future::ok(())) - }); - - // Start polling `done` - match rt.block_on(done) { - Ok(_) => (), - Err(e) => { - println!("Fatal: {}", e); - () + } else { + stream_ended = true; } } } diff --git a/tokio-xmpp/examples/download_avatars.rs b/tokio-xmpp/examples/download_avatars.rs index ec4222d2d7e5f23877701cc99e8980003c2af2e0..7562d07050e970c8a6bcc3aa7882f7f98cd41ee5 100644 --- a/tokio-xmpp/examples/download_avatars.rs +++ b/tokio-xmpp/examples/download_avatars.rs @@ -1,12 +1,12 @@ -use futures::{future, Future, Sink, Stream}; +use futures::stream::StreamExt; use std::convert::TryFrom; use std::env::args; use std::fs::{create_dir_all, File}; use std::io::{self, Write}; use std::process::exit; use std::str::FromStr; -use tokio::runtime::current_thread::Runtime; -use tokio_xmpp::{Client, Packet}; +use tokio; +use tokio_xmpp::Client; use xmpp_parsers::{ avatar::{Data as AvatarData, Metadata as AvatarMetadata}, caps::{compute_disco, hash_caps, Caps}, @@ -22,162 +22,153 @@ use xmpp_parsers::{ NodeName, }, stanza_error::{DefinedCondition, ErrorType, StanzaError}, - Jid, + Element, Jid, }; -fn main() { +#[tokio::main] +async fn main() { let args: Vec = args().collect(); if args.len() != 3 { println!("Usage: {} ", args[0]); exit(1); } let jid = &args[1]; - let password = &args[2]; + let password = args[2].clone(); - // tokio_core context - let mut rt = Runtime::new().unwrap(); // Client instance - let client = Client::new(jid, password).unwrap(); - - // Make the two interfaces for sending and receiving independent - // of each other so we can move one into a closure. - let (sink, stream) = client.split(); - - // Create outgoing pipe - let (mut tx, rx) = futures::unsync::mpsc::unbounded(); - rt.spawn( - rx.forward(sink.sink_map_err(|_| panic!("Pipe"))) - .map(|(rx, mut sink)| { - drop(rx); - let _ = sink.close(); - }) - .map_err(|e| { - panic!("Send error: {:?}", e); - }), - ); + let mut client = Client::new(jid, password).unwrap(); let disco_info = make_disco(); // Main loop, processes events let mut wait_for_stream_end = false; - let done = stream.for_each(move |event| { - // Helper function to send an iq error. - let mut send_error = |to, id, type_, condition, text: &str| { - let error = StanzaError::new(type_, condition, "en", text); - let iq = Iq::from_error(id, error).with_to(to); - tx.start_send(Packet::Stanza(iq.into())).unwrap(); - }; - - if wait_for_stream_end { - /* Do nothing */ - } else if event.is_online() { - println!("Online!"); - - let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp"); - let presence = make_presence(caps); - tx.start_send(Packet::Stanza(presence.into())).unwrap(); - } else if let Some(stanza) = event.into_stanza() { - if stanza.is("iq", "jabber:client") { - let iq = Iq::try_from(stanza).unwrap(); - if let IqType::Get(payload) = iq.payload { - if payload.is("query", ns::DISCO_INFO) { - let query = DiscoInfoQuery::try_from(payload); - match query { - Ok(query) => { - let mut disco = disco_info.clone(); - disco.node = query.node; - let iq = - Iq::from_result(iq.id, Some(disco)).with_to(iq.from.unwrap()); - tx.start_send(Packet::Stanza(iq.into())).unwrap(); + let mut stream_ended = false; + while !stream_ended { + if let Some(event) = client.next().await { + if wait_for_stream_end { + /* Do nothing */ + } else if event.is_online() { + println!("Online!"); + + let caps = get_disco_caps(&disco_info, "https://gitlab.com/xmpp-rs/tokio-xmpp"); + let presence = make_presence(caps); + client.send_stanza(presence.into()).await.unwrap(); + } else if let Some(stanza) = event.into_stanza() { + if stanza.is("iq", "jabber:client") { + let iq = Iq::try_from(stanza).unwrap(); + if let IqType::Get(payload) = iq.payload { + if payload.is("query", ns::DISCO_INFO) { + let query = DiscoInfoQuery::try_from(payload); + match query { + Ok(query) => { + let mut disco = disco_info.clone(); + disco.node = query.node; + let iq = Iq::from_result(iq.id, Some(disco)) + .with_to(iq.from.unwrap()); + client.send_stanza(iq.into()).await.unwrap(); + } + Err(err) => client + .send_stanza(make_error( + iq.from.unwrap(), + iq.id, + ErrorType::Modify, + DefinedCondition::BadRequest, + &format!("{}", err), + )) + .await + .unwrap(), } - Err(err) => { - send_error( + } else { + // We MUST answer unhandled get iqs with a service-unavailable error. + client + .send_stanza(make_error( iq.from.unwrap(), iq.id, - ErrorType::Modify, - DefinedCondition::BadRequest, - &format!("{}", err), - ); - } + ErrorType::Cancel, + DefinedCondition::ServiceUnavailable, + "No handler defined for this kind of iq.", + )) + .await + .unwrap(); } - } else { - // We MUST answer unhandled get iqs with a service-unavailable error. - send_error( - iq.from.unwrap(), - iq.id, - ErrorType::Cancel, - DefinedCondition::ServiceUnavailable, - "No handler defined for this kind of iq.", - ); - } - } else if let IqType::Result(Some(payload)) = iq.payload { - if payload.is("pubsub", ns::PUBSUB) { - let pubsub = PubSub::try_from(payload).unwrap(); - let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap()); - handle_iq_result(pubsub, &from); + } else if let IqType::Result(Some(payload)) = iq.payload { + if payload.is("pubsub", ns::PUBSUB) { + let pubsub = PubSub::try_from(payload).unwrap(); + let from = iq.from.clone().unwrap_or(Jid::from_str(jid).unwrap()); + handle_iq_result(pubsub, &from); + } + } else if let IqType::Set(_) = iq.payload { + // We MUST answer unhandled set iqs with a service-unavailable error. + client + .send_stanza(make_error( + iq.from.unwrap(), + iq.id, + ErrorType::Cancel, + DefinedCondition::ServiceUnavailable, + "No handler defined for this kind of iq.", + )) + .await + .unwrap(); } - } else if let IqType::Set(_) = iq.payload { - // We MUST answer unhandled set iqs with a service-unavailable error. - send_error( - iq.from.unwrap(), - iq.id, - ErrorType::Cancel, - DefinedCondition::ServiceUnavailable, - "No handler defined for this kind of iq.", - ); - } - } else if stanza.is("message", "jabber:client") { - let message = Message::try_from(stanza).unwrap(); - let from = message.from.clone().unwrap(); - if let Some(body) = message.get_best_body(vec!["en"]) { - if body.1 .0 == "die" { - println!("Secret die command triggered by {}", from); - wait_for_stream_end = true; - tx.start_send(Packet::StreamEnd).unwrap(); + } else if stanza.is("message", "jabber:client") { + let message = Message::try_from(stanza).unwrap(); + let from = message.from.clone().unwrap(); + if let Some(body) = message.get_best_body(vec!["en"]) { + if body.0 == "die" { + println!("Secret die command triggered by {}", from); + wait_for_stream_end = true; + client.send_end().await.unwrap(); + } } - } - for child in message.payloads { - if child.is("event", ns::PUBSUB_EVENT) { - let event = PubSubEvent::try_from(child).unwrap(); - if let PubSubEvent::PublishedItems { node, items } = event { - if node.0 == ns::AVATAR_METADATA { - for item in items.into_iter() { - let payload = item.payload.clone().unwrap(); - if payload.is("metadata", ns::AVATAR_METADATA) { - // TODO: do something with these metadata. - let _metadata = AvatarMetadata::try_from(payload).unwrap(); - println!( - "{} has published an avatar, downloading...", - from.clone() - ); - let iq = download_avatar(from.clone()); - tx.start_send(Packet::Stanza(iq.into())).unwrap(); + for child in message.payloads { + if child.is("event", ns::PUBSUB_EVENT) { + let event = PubSubEvent::try_from(child).unwrap(); + if let PubSubEvent::PublishedItems { node, items } = event { + if node.0 == ns::AVATAR_METADATA { + for item in items.into_iter() { + let payload = item.payload.clone().unwrap(); + if payload.is("metadata", ns::AVATAR_METADATA) { + // TODO: do something with these metadata. + let _metadata = + AvatarMetadata::try_from(payload).unwrap(); + println!( + "{} has published an avatar, downloading...", + from.clone() + ); + let iq = download_avatar(from.clone()); + client.send_stanza(iq.into()).await.unwrap(); + } } } } } } + } else if stanza.is("presence", "jabber:client") { + // Nothing to do here. + () + } else { + panic!("Unknown stanza: {}", String::from(&stanza)); } - } else if stanza.is("presence", "jabber:client") { - // Nothing to do here. - } else { - panic!("Unknown stanza: {}", String::from(&stanza)); } - } - - future::ok(()) - }); - - // Start polling `done` - match rt.block_on(done) { - Ok(_) => (), - Err(e) => { - println!("Fatal: {}", e); - () + } else { + println!("stream_ended"); + stream_ended = true; } } } +fn make_error( + to: Jid, + id: String, + type_: ErrorType, + condition: DefinedCondition, + text: &str, +) -> Element { + let error = StanzaError::new(type_, condition, "en", text); + let iq = Iq::from_error(id, error).with_to(to); + iq.into() +} + fn make_disco() -> DiscoInfoResult { let identities = vec![Identity::new("client", "bot", "en", "tokio-xmpp")]; let features = vec![ @@ -235,6 +226,7 @@ fn handle_iq_result(pubsub: PubSub, from: &Jid) { } } +// TODO: may use tokio? fn save_avatar(from: &Jid, id: String, data: &[u8]) -> io::Result<()> { let directory = format!("data/{}", from); let filename = format!("data/{}/{}", from, id); diff --git a/tokio-xmpp/examples/echo_bot.rs b/tokio-xmpp/examples/echo_bot.rs index 3fcd9aa7c7a115c53a4094137f1b320b117e5f16..4ba837cf4c829a4c55339eb8573f04a6420e6328 100644 --- a/tokio-xmpp/examples/echo_bot.rs +++ b/tokio-xmpp/examples/echo_bot.rs @@ -1,14 +1,15 @@ -use futures::{future, Future, Sink, Stream}; +use futures::stream::StreamExt; use std::convert::TryFrom; use std::env::args; use std::process::exit; -use tokio::runtime::current_thread::Runtime; -use tokio_xmpp::{Client, Packet}; +use tokio; +use tokio_xmpp::Client; use xmpp_parsers::message::{Body, Message, MessageType}; use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType}; use xmpp_parsers::{Element, Jid}; -fn main() { +#[tokio::main] +async fn main() { let args: Vec = args().collect(); if args.len() != 3 { println!("Usage: {} ", args[0]); @@ -17,72 +18,50 @@ fn main() { let jid = &args[1]; let password = &args[2]; - // tokio_core context - let mut rt = Runtime::new().unwrap(); // Client instance - let client = Client::new(jid, password).unwrap(); - - // Make the two interfaces for sending and receiving independent - // of each other so we can move one into a closure. - let (sink, stream) = client.split(); - - // Create outgoing pipe - let (mut tx, rx) = futures::unsync::mpsc::unbounded(); - rt.spawn( - rx.forward(sink.sink_map_err(|_| panic!("Pipe"))) - .map(|(rx, mut sink)| { - drop(rx); - let _ = sink.close(); - }) - .map_err(|e| { - panic!("Send error: {:?}", e); - }), - ); + let mut client = Client::new(jid, password.to_owned()).unwrap(); + client.set_reconnect(true); // Main loop, processes events let mut wait_for_stream_end = false; - let done = stream.for_each(move |event| { - if wait_for_stream_end { - /* Do nothing */ - } else if event.is_online() { - let jid = event - .get_jid() - .map(|jid| format!("{}", jid)) - .unwrap_or("unknown".to_owned()); - println!("Online at {}", jid); + let mut stream_ended = false; + while !stream_ended { + if let Some(event) = client.next().await { + println!("event: {:?}", event); + if wait_for_stream_end { + /* Do nothing */ + } else if event.is_online() { + let jid = event + .get_jid() + .map(|jid| format!("{}", jid)) + .unwrap_or("unknown".to_owned()); + println!("Online at {}", jid); - let presence = make_presence(); - tx.start_send(Packet::Stanza(presence)).unwrap(); - } else if let Some(message) = event - .into_stanza() - .and_then(|stanza| Message::try_from(stanza).ok()) - { - match (message.from, message.bodies.get("")) { - (Some(ref from), Some(ref body)) if body.0 == "die" => { - println!("Secret die command triggered by {}", from); - wait_for_stream_end = true; - tx.start_send(Packet::StreamEnd).unwrap(); - } - (Some(ref from), Some(ref body)) => { - if message.type_ != MessageType::Error { - // This is a message we'll echo - let reply = make_reply(from.clone(), &body.0); - tx.start_send(Packet::Stanza(reply)).unwrap(); + let presence = make_presence(); + client.send_stanza(presence).await.unwrap(); + } else if let Some(message) = event + .into_stanza() + .and_then(|stanza| Message::try_from(stanza).ok()) + { + match (message.from, message.bodies.get("")) { + (Some(ref from), Some(ref body)) if body.0 == "die" => { + println!("Secret die command triggered by {}", from); + wait_for_stream_end = true; + client.send_end().await.unwrap(); + } + (Some(ref from), Some(ref body)) => { + if message.type_ != MessageType::Error { + // This is a message we'll echo + let reply = make_reply(from.clone(), &body.0); + client.send_stanza(reply).await.unwrap(); + } } + _ => {} } - _ => {} } - } - - future::ok(()) - }); - - // Start polling `done` - match rt.block_on(done) { - Ok(_) => (), - Err(e) => { - println!("Fatal: {}", e); - () + } else { + println!("stream_ended"); + stream_ended = true; } } } diff --git a/tokio-xmpp/examples/echo_component.rs b/tokio-xmpp/examples/echo_component.rs index 29218ffe14cde3afd986f894103ba7d2f3add4c4..ef8844303a3a1f0518268db6bd51a8ca5abe0d23 100644 --- a/tokio-xmpp/examples/echo_component.rs +++ b/tokio-xmpp/examples/echo_component.rs @@ -1,15 +1,15 @@ -use futures::{future, Sink, Stream}; +use futures::stream::StreamExt; use std::convert::TryFrom; use std::env::args; use std::process::exit; use std::str::FromStr; -use tokio::runtime::current_thread::Runtime; use tokio_xmpp::Component; use xmpp_parsers::message::{Body, Message, MessageType}; use xmpp_parsers::presence::{Presence, Show as PresenceShow, Type as PresenceType}; use xmpp_parsers::{Element, Jid}; -fn main() { +#[tokio::main] +async fn main() { let args: Vec = args().collect(); if args.len() < 3 || args.len() > 5 { println!("Usage: {} [server] [port]", args[0]); @@ -24,57 +24,38 @@ fn main() { .unwrap_or("127.0.0.1".to_owned()); let port: u16 = args.get(4).unwrap().parse().unwrap_or(5347u16); - // tokio_core context - let mut rt = Runtime::new().unwrap(); // Component instance println!("{} {} {} {}", jid, password, server, port); - let component = Component::new(jid, password, server, port).unwrap(); + let mut component = Component::new(jid, password, server, port).await.unwrap(); // Make the two interfaces for sending and receiving independent // of each other so we can move one into a closure. - println!("Got it: {}", component.jid.clone()); - let (mut sink, stream) = component.split(); - // Wrap sink in Option so that we can take() it for the send(self) - // to consume and return it back when ready. - let mut send = move |stanza| { - sink.start_send(stanza).expect("start_send"); - }; - // Main loop, processes events - let done = stream.for_each(|event| { - if event.is_online() { - println!("Online!"); + println!("Online: {}", component.jid); + + // TODO: replace these hardcoded JIDs + let presence = make_presence( + Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(), + Jid::from_str("linkmauve@linkmauve.fr").unwrap(), + ); + component.send_stanza(presence).await.unwrap(); - // TODO: replace these hardcoded JIDs - let presence = make_presence( - Jid::from_str("test@component.linkmauve.fr/coucou").unwrap(), - Jid::from_str("linkmauve@linkmauve.fr").unwrap(), - ); - send(presence); - } else if let Some(message) = event - .into_stanza() - .and_then(|stanza| Message::try_from(stanza).ok()) - { - // This is a message we'll echo - match (message.from, message.bodies.get("")) { - (Some(from), Some(body)) => { - if message.type_ != MessageType::Error { - let reply = make_reply(from, &body.0); - send(reply); + // Main loop, processes events + loop { + if let Some(stanza) = component.next().await { + if let Some(message) = Message::try_from(stanza).ok() { + // This is a message we'll echo + match (message.from, message.bodies.get("")) { + (Some(from), Some(body)) => { + if message.type_ != MessageType::Error { + let reply = make_reply(from, &body.0); + component.send_stanza(reply).await.unwrap(); + } } + _ => (), } - _ => (), } - } - - Box::new(future::ok(())) - }); - - // Start polling `done` - match rt.block_on(done) { - Ok(_) => (), - Err(e) => { - println!("Fatal: {}", e); - () + } else { + break; } } } diff --git a/tokio-xmpp/src/client/auth.rs b/tokio-xmpp/src/client/auth.rs index ca62845508af9bc13ca257172d0b707f39cd8630..28a60dc85d0ab767210df65aa47bf6e77ede09e2 100644 --- a/tokio-xmpp/src/client/auth.rs +++ b/tokio-xmpp/src/client/auth.rs @@ -1,7 +1,4 @@ -use futures::{ - future::{err, ok, IntoFuture}, - Future, Poll, Stream, -}; +use futures::stream::StreamExt; use sasl::client::mechanisms::{Anonymous, Plain, Scram}; use sasl::client::Mechanism; use sasl::common::scram::{Sha1, Sha256}; @@ -9,7 +6,7 @@ use sasl::common::Credentials; use std::collections::HashSet; use std::convert::TryFrom; use std::str::FromStr; -use tokio_io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success}; use crate::xmpp_codec::Packet; @@ -18,109 +15,70 @@ use crate::{AuthError, Error, ProtocolError}; const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; -pub struct ClientAuth { - future: Box, Error = Error>>, -} - -impl ClientAuth { - pub fn new(stream: XMPPStream, creds: Credentials) -> Result { - let local_mechs: Vec Box>> = vec![ - Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), - Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), - Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())), - Box::new(|| Box::new(Anonymous::new())), - ]; +pub async fn auth( + mut stream: XMPPStream, + creds: Credentials, +) -> Result { + let local_mechs: Vec Box + Send>> = vec![ + Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), + Box::new(|| Box::new(Scram::::from_credentials(creds.clone()).unwrap())), + Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())), + Box::new(|| Box::new(Anonymous::new())), + ]; - let remote_mechs: HashSet = stream - .stream_features - .get_child("mechanisms", NS_XMPP_SASL) - .ok_or(AuthError::NoMechanism)? - .children() - .filter(|child| child.is("mechanism", NS_XMPP_SASL)) - .map(|mech_el| mech_el.text()) - .collect(); + let remote_mechs: HashSet = stream + .stream_features + .get_child("mechanisms", NS_XMPP_SASL) + .ok_or(AuthError::NoMechanism)? + .children() + .filter(|child| child.is("mechanism", NS_XMPP_SASL)) + .map(|mech_el| mech_el.text()) + .collect(); - for local_mech in local_mechs { - let mut mechanism = local_mech(); - if remote_mechs.contains(mechanism.name()) { - let initial = mechanism.initial().map_err(AuthError::Sasl)?; - let mechanism_name = - XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?; + for local_mech in local_mechs { + let mut mechanism = local_mech(); + if remote_mechs.contains(mechanism.name()) { + let initial = mechanism.initial().map_err(AuthError::Sasl)?; + let mechanism_name = + XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?; - let send_initial = Box::new(stream.send_stanza(Auth { + stream + .send_stanza(Auth { mechanism: mechanism_name, data: initial, - })) - .map_err(Error::Io); - let future = Box::new( - send_initial - .and_then(|stream| Self::handle_challenge(stream, mechanism)) - .and_then(|stream| stream.restart()), - ); - return Ok(ClientAuth { future }); - } - } + }) + .await?; - Err(AuthError::NoMechanism)? - } + loop { + match stream.next().await { + Some(Ok(Packet::Stanza(stanza))) => { + if let Ok(challenge) = Challenge::try_from(stanza.clone()) { + let response = mechanism + .response(&challenge.data) + .map_err(|e| AuthError::Sasl(e))?; - fn handle_challenge( - stream: XMPPStream, - mut mechanism: Box, - ) -> Box, Error = Error>> { - Box::new( - stream - .into_future() - .map_err(|(e, _stream)| e.into()) - .and_then(|(stanza, stream)| { - match stanza { - Some(Packet::Stanza(stanza)) => { - if let Ok(challenge) = Challenge::try_from(stanza.clone()) { - let response = mechanism.response(&challenge.data); - Box::new( - response - .map_err(|e| AuthError::Sasl(e).into()) - .into_future() - .and_then(|response| { - // Send response and loop - stream - .send_stanza(Response { data: response }) - .map_err(Error::Io) - .and_then(|stream| { - Self::handle_challenge(stream, mechanism) - }) - }), - ) - } else if let Ok(_) = Success::try_from(stanza.clone()) { - Box::new(ok(stream)) - } else if let Ok(failure) = Failure::try_from(stanza.clone()) { - Box::new(err(Error::Auth(AuthError::Fail( - failure.defined_condition, - )))) - } else if stanza.name() == "failure" { - // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1 - Box::new(err(Error::Auth(AuthError::Sasl("failure".to_string())))) - } else { - // ignore and loop - Self::handle_challenge(stream, mechanism) - } - } - Some(_) => { + // Send response and loop + stream.send_stanza(Response { data: response }).await?; + } else if let Ok(_) = Success::try_from(stanza.clone()) { + return Ok(stream.into_inner()); + } else if let Ok(failure) = Failure::try_from(stanza.clone()) { + return Err(Error::Auth(AuthError::Fail(failure.defined_condition))); + } else if stanza.name() == "failure" { + // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1 + return Err(Error::Auth(AuthError::Sasl("failure".to_string()))); + } else { // ignore and loop - Self::handle_challenge(stream, mechanism) } - None => Box::new(err(Error::Disconnected)), } - }), - ) + Some(Ok(_)) => { + // ignore and loop + } + Some(Err(e)) => return Err(e), + None => return Err(Error::Disconnected), + } + } + } } -} - -impl Future for ClientAuth { - type Item = XMPPStream; - type Error = Error; - fn poll(&mut self) -> Poll { - self.future.poll() - } + Err(AuthError::NoMechanism.into()) } diff --git a/tokio-xmpp/src/client/bind.rs b/tokio-xmpp/src/client/bind.rs index f164846f4cb47e6b73e2d8deb7374d51a80a0210..6331c94cd0998322950936b6faed80559baa0bca 100644 --- a/tokio-xmpp/src/client/bind.rs +++ b/tokio-xmpp/src/client/bind.rs @@ -1,7 +1,7 @@ -use futures::{sink, Async, Future, Poll, Stream}; +use futures::stream::StreamExt; use std::convert::TryFrom; -use std::mem::replace; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::marker::Unpin; +use tokio::io::{AsyncRead, AsyncWrite}; use xmpp_parsers::bind::{BindQuery, BindResponse}; use xmpp_parsers::iq::{Iq, IqType}; use xmpp_parsers::Jid; @@ -13,90 +13,43 @@ use crate::{Error, ProtocolError}; const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind"; const BIND_REQ_ID: &str = "resource-bind"; -pub enum ClientBind { - Unsupported(XMPPStream), - WaitSend(sink::Send>), - WaitRecv(XMPPStream), - Invalid, -} - -impl ClientBind { - /// Consumes and returns the stream to express that you cannot use - /// the stream for anything else until the resource binding - /// req/resp are done. - pub fn new(stream: XMPPStream) -> Self { - match stream.stream_features.get_child("bind", NS_XMPP_BIND) { - None => +pub async fn bind( + mut stream: XMPPStream, +) -> Result, Error> { + match stream.stream_features.get_child("bind", NS_XMPP_BIND) { + None => { // No resource binding available, // return the (probably // usable) stream immediately - { - ClientBind::Unsupported(stream) - } - Some(_) => { - let resource; - if let Jid::Full(jid) = stream.jid.clone() { - resource = Some(jid.resource); - } else { - resource = None; - } - let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); - let send = stream.send_stanza(iq); - ClientBind::WaitSend(send) - } + return Ok(stream); } - } -} + Some(_) => { + let resource = if let Jid::Full(jid) = stream.jid.clone() { + Some(jid.resource) + } else { + None + }; + let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource)); + stream.send_stanza(iq).await?; -impl Future for ClientBind { - type Item = XMPPStream; - type Error = Error; - - fn poll(&mut self) -> Poll { - let state = replace(self, ClientBind::Invalid); - - match state { - ClientBind::Unsupported(stream) => Ok(Async::Ready(stream)), - ClientBind::WaitSend(mut send) => match send.poll() { - Ok(Async::Ready(stream)) => { - replace(self, ClientBind::WaitRecv(stream)); - self.poll() - } - Ok(Async::NotReady) => { - replace(self, ClientBind::WaitSend(send)); - Ok(Async::NotReady) - } - Err(e) => Err(e)?, - }, - ClientBind::WaitRecv(mut stream) => match stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => match Iq::try_from(stanza) { - Ok(iq) => { - if iq.id == BIND_REQ_ID { - match iq.payload { - IqType::Result(payload) => { - payload - .and_then(|payload| BindResponse::try_from(payload).ok()) - .map(|bind| stream.jid = bind.into()); - Ok(Async::Ready(stream)) - } - _ => Err(ProtocolError::InvalidBindResponse)?, + loop { + match stream.next().await { + Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) { + Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload { + IqType::Result(payload) => { + payload + .and_then(|payload| BindResponse::try_from(payload).ok()) + .map(|bind| stream.jid = bind.into()); + return Ok(stream); } - } else { - Ok(Async::NotReady) - } - } - _ => Ok(Async::NotReady), - }, - Ok(Async::Ready(_)) => { - replace(self, ClientBind::WaitRecv(stream)); - self.poll() - } - Ok(Async::NotReady) => { - replace(self, ClientBind::WaitRecv(stream)); - Ok(Async::NotReady) + _ => return Err(ProtocolError::InvalidBindResponse.into()), + }, + _ => {} + }, + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e), + None => return Err(Error::Disconnected), } - Err(e) => Err(e)?, - }, - ClientBind::Invalid => unreachable!(), + } } } } diff --git a/tokio-xmpp/src/client/mod.rs b/tokio-xmpp/src/client/mod.rs index af61311f2819d113fc26bbade859c6cc397dc7e4..9a6b2af30b88523db20c95a75425af3d7a5ae4a8 100644 --- a/tokio-xmpp/src/client/mod.rs +++ b/tokio-xmpp/src/client/mod.rs @@ -1,28 +1,33 @@ -use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream}; use idna; use sasl::common::{ChannelBinding, Credentials}; use std::mem::replace; +use std::pin::Pin; use std::str::FromStr; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; -use tokio_io::{AsyncRead, AsyncWrite}; +use tokio::task::JoinHandle; +use tokio::task::LocalSet; use tokio_tls::TlsStream; -use xmpp_parsers::{Jid, JidParseError}; +use xmpp_parsers::{Element, Jid, JidParseError}; use super::event::Event; -use super::happy_eyeballs::Connecter; -use super::starttls::{StartTlsClient, NS_XMPP_TLS}; +use super::happy_eyeballs::connect; +use super::starttls::{starttls, NS_XMPP_TLS}; use super::xmpp_codec::Packet; use super::xmpp_stream; use super::{Error, ProtocolError}; mod auth; -use self::auth::ClientAuth; mod bind; -use self::bind::ClientBind; /// XMPP client connection and state pub struct Client { state: ClientState, + jid: Jid, + password: String, + reconnect: bool, } type XMPPStream = xmpp_stream::XMPPStream>; @@ -31,7 +36,7 @@ const NS_JABBER_CLIENT: &str = "jabber:client"; enum ClientState { Invalid, Disconnected, - Connecting(Box>), + Connecting(JoinHandle>, LocalSet), Connected(XMPPStream), } @@ -40,87 +45,87 @@ impl Client { /// /// Start polling the returned instance so that it will connect /// and yield events. - pub fn new(jid: &str, password: &str) -> Result { + pub fn new>(jid: &str, password: P) -> Result { let jid = Jid::from_str(jid)?; - let client = Self::new_with_jid(jid, password); + let client = Self::new_with_jid(jid, password.into()); Ok(client) } /// Start a new client given that the JID is already parsed. - pub fn new_with_jid(jid: Jid, password: &str) -> Self { - let password = password.to_owned(); - let connect = Self::make_connect(jid, password.clone()); + pub fn new_with_jid(jid: Jid, password: String) -> Self { + let local = LocalSet::new(); + let connect = local.spawn_local(Self::connect(jid.clone(), password.clone())); let client = Client { - state: ClientState::Connecting(Box::new(connect)), + jid, + password, + state: ClientState::Connecting(connect, local), + reconnect: false, }; client } - fn make_connect(jid: Jid, password: String) -> impl Future { + /// Set whether to reconnect (`true`) or end the stream (`false`) + /// when a connection to the server has ended. + pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self { + self.reconnect = reconnect; + self + } + + async fn connect(jid: Jid, password: String) -> Result { let username = jid.clone().node().unwrap(); - let jid1 = jid.clone(); - let jid2 = jid.clone(); let password = password; - done(idna::domain_to_ascii(&jid.domain())) - .map_err(|_| Error::Idna) - .and_then(|domain| { - done(Connecter::from_lookup( - &domain, - Some("_xmpp-client._tcp"), - 5222, - )) - }) - .flatten() - .and_then(move |tcp_stream| { - xmpp_stream::XMPPStream::start(tcp_stream, jid1, NS_JABBER_CLIENT.to_owned()) - }) - .and_then(|xmpp_stream| { - if Self::can_starttls(&xmpp_stream) { - Ok(Self::starttls(xmpp_stream)) - } else { - Err(Error::Protocol(ProtocolError::NoTls)) - } - }) - .flatten() - .and_then(|tls_stream| XMPPStream::start(tls_stream, jid2, NS_JABBER_CLIENT.to_owned())) - .and_then( - move |xmpp_stream| done(Self::auth(xmpp_stream, username, password)), // TODO: flatten? - ) - .and_then(|auth| auth) - .and_then(|xmpp_stream| Self::bind(xmpp_stream)) - .and_then(|xmpp_stream| { - // println!("Bound to {}", xmpp_stream.jid); - Ok(xmpp_stream) - }) - } - - fn can_starttls(stream: &xmpp_stream::XMPPStream) -> bool { - stream + let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?; + + let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?; + + let xmpp_stream = + xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?; + let xmpp_stream = if Self::can_starttls(&xmpp_stream) { + Self::starttls(xmpp_stream).await? + } else { + return Err(Error::Protocol(ProtocolError::NoTls)); + }; + + let xmpp_stream = Self::auth(xmpp_stream, username, password).await?; + let xmpp_stream = Self::bind(xmpp_stream).await?; + Ok(xmpp_stream) + } + + fn can_starttls( + xmpp_stream: &xmpp_stream::XMPPStream, + ) -> bool { + xmpp_stream .stream_features .get_child("starttls", NS_XMPP_TLS) .is_some() } - fn starttls( - stream: xmpp_stream::XMPPStream, - ) -> StartTlsClient { - StartTlsClient::from_stream(stream) + async fn starttls( + xmpp_stream: xmpp_stream::XMPPStream, + ) -> Result>, Error> { + let jid = xmpp_stream.jid.clone(); + let tls_stream = starttls(xmpp_stream).await?; + xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await } - fn auth( - stream: xmpp_stream::XMPPStream, + async fn auth( + xmpp_stream: xmpp_stream::XMPPStream, username: String, password: String, - ) -> Result, Error> { + ) -> Result, Error> { + let jid = xmpp_stream.jid.clone(); let creds = Credentials::default() .with_username(username) .with_password(password) .with_channel_binding(ChannelBinding::None); - ClientAuth::new(stream, creds) + let stream = auth::auth(xmpp_stream, creds).await?; + xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await } - fn bind(stream: xmpp_stream::XMPPStream) -> ClientBind { - ClientBind::new(stream) + async fn bind( + stream: xmpp_stream::XMPPStream, + ) -> Result, Error> { + bind::bind(stream).await } /// Get the client's bound JID (the one reported by the XMPP @@ -131,102 +136,150 @@ impl Client { _ => None, } } + + /// Send stanza + pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> { + self.send(Packet::Stanza(stanza)).await + } + + /// End connection + pub async fn send_end(&mut self) -> Result<(), Error> { + self.send(Packet::StreamEnd).await + } } impl Stream for Client { type Item = Event; - type Error = Error; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let state = replace(&mut self.state, ClientState::Invalid); match state { - ClientState::Invalid => Err(Error::InvalidState), - ClientState::Disconnected => Ok(Async::Ready(None)), - ClientState::Connecting(mut connect) => match connect.poll() { - Ok(Async::Ready(stream)) => { - let jid = stream.jid.clone(); - self.state = ClientState::Connected(stream); - Ok(Async::Ready(Some(Event::Online(jid)))) - } - Ok(Async::NotReady) => { - self.state = ClientState::Connecting(connect); - Ok(Async::NotReady) + ClientState::Invalid => panic!("Invalid client state"), + ClientState::Disconnected if self.reconnect => { + // TODO: add timeout + let mut local = LocalSet::new(); + let connect = + local.spawn_local(Self::connect(self.jid.clone(), self.password.clone())); + let _ = Pin::new(&mut local).poll(cx); + self.state = ClientState::Connecting(connect, local); + self.poll_next(cx) + } + ClientState::Disconnected => Poll::Ready(None), + ClientState::Connecting(mut connect, mut local) => { + match Pin::new(&mut connect).poll(cx) { + Poll::Ready(Ok(Ok(stream))) => { + let bound_jid = stream.jid.clone(); + self.state = ClientState::Connected(stream); + Poll::Ready(Some(Event::Online(bound_jid))) + } + Poll::Ready(Ok(Err(e))) => { + self.state = ClientState::Disconnected; + return Poll::Ready(Some(Event::Disconnected(e.into()))); + } + Poll::Ready(Err(e)) => { + self.state = ClientState::Disconnected; + panic!("connect task: {}", e); + } + Poll::Pending => { + let _ = Pin::new(&mut local).poll(cx); + + self.state = ClientState::Connecting(connect, local); + Poll::Pending + } } - Err(e) => Err(e), - }, + } ClientState::Connected(mut stream) => { // Poll sink - match stream.poll_complete() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(())) => (), - Err(e) => return Err(e)?, + match Pin::new(&mut stream).poll_ready(cx) { + Poll::Pending => (), + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(e)) => { + self.state = ClientState::Disconnected; + return Poll::Ready(Some(Event::Disconnected(e.into()))); + } }; // Poll stream - match stream.poll() { - Ok(Async::Ready(None)) => { + match Pin::new(&mut stream).poll_next(cx) { + Poll::Ready(None) => { // EOF self.state = ClientState::Disconnected; - Ok(Async::Ready(Some(Event::Disconnected))) + Poll::Ready(Some(Event::Disconnected(Error::Disconnected))) } - Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => { + Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => { // Receive stanza self.state = ClientState::Connected(stream); - Ok(Async::Ready(Some(Event::Stanza(stanza)))) + Poll::Ready(Some(Event::Stanza(stanza))) } - Ok(Async::Ready(Some(Packet::Text(_)))) => { + Poll::Ready(Some(Ok(Packet::Text(_)))) => { // Ignore text between stanzas self.state = ClientState::Connected(stream); - Ok(Async::NotReady) + Poll::Pending } - Ok(Async::Ready(Some(Packet::StreamStart(_)))) => { + Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => { // - Err(ProtocolError::InvalidStreamStart.into()) + self.state = ClientState::Disconnected; + Poll::Ready(Some(Event::Disconnected( + ProtocolError::InvalidStreamStart.into(), + ))) } - Ok(Async::Ready(Some(Packet::StreamEnd))) => { + Poll::Ready(Some(Ok(Packet::StreamEnd))) => { // End of stream: - Ok(Async::Ready(None)) + self.state = ClientState::Disconnected; + Poll::Ready(Some(Event::Disconnected(Error::Disconnected))) } - Ok(Async::NotReady) => { + Poll::Pending => { // Try again later self.state = ClientState::Connected(stream); - Ok(Async::NotReady) + Poll::Pending + } + Poll::Ready(Some(Err(e))) => { + self.state = ClientState::Disconnected; + Poll::Ready(Some(Event::Disconnected(e.into()))) } - Err(e) => Err(e)?, } } } } } -impl Sink for Client { - type SinkItem = Packet; - type SinkError = Error; +impl Sink for Client { + type Error = Error; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { match self.state { - ClientState::Connected(ref mut stream) => Ok(stream.start_send(item)?), - _ => Ok(AsyncSink::NotReady(item)), + ClientState::Connected(ref mut stream) => { + Pin::new(stream).start_send(item).map_err(|e| e.into()) + } + _ => Err(Error::InvalidState), } } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.state { - ClientState::Connected(ref mut stream) => stream.poll_complete().map_err(|e| e.into()), - _ => Ok(Async::Ready(())), + ClientState::Connected(ref mut stream) => { + Pin::new(stream).poll_ready(cx).map_err(|e| e.into()) + } + _ => Poll::Pending, } } - /// This closes the inner TCP stream. - /// - /// To synchronize your shutdown with the server side, you should - /// first send `Packet::StreamEnd` and wait for the end of the - /// incoming stream before closing the connection. - fn close(&mut self) -> Poll<(), Self::SinkError> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { match self.state { - ClientState::Connected(ref mut stream) => stream.close().map_err(|e| e.into()), - _ => Ok(Async::Ready(())), + ClientState::Connected(ref mut stream) => { + Pin::new(stream).poll_flush(cx).map_err(|e| e.into()) + } + _ => Poll::Pending, + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.state { + ClientState::Connected(ref mut stream) => { + Pin::new(stream).poll_close(cx).map_err(|e| e.into()) + } + _ => Poll::Pending, } } } diff --git a/tokio-xmpp/src/component/auth.rs b/tokio-xmpp/src/component/auth.rs index 44104f4c19052fe1db703563fab894cf32c2d7ca..739c97f03ac80d18de1f840b46865c7f263d5ed0 100644 --- a/tokio-xmpp/src/component/auth.rs +++ b/tokio-xmpp/src/component/auth.rs @@ -1,6 +1,6 @@ -use futures::{sink, Async, Future, Poll, Stream}; -use std::mem::replace; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::stream::StreamExt; +use std::marker::Unpin; +use tokio::io::{AsyncRead, AsyncWrite}; use xmpp_parsers::component::Handshake; use crate::xmpp_codec::Packet; @@ -9,81 +9,27 @@ use crate::{AuthError, Error}; const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept"; -pub struct ComponentAuth { - state: ComponentAuthState, -} - -enum ComponentAuthState { - WaitSend(sink::Send>), - WaitRecv(XMPPStream), - Invalid, -} - -impl ComponentAuth { - // TODO: doesn't have to be a Result<> actually - pub fn new(stream: XMPPStream, password: String) -> Result { - // FIXME: huge hack, shouldn’t be an element! - let sid = stream.stream_features.name().to_owned(); - let mut this = ComponentAuth { - state: ComponentAuthState::Invalid, - }; - this.send( - stream, - Handshake::from_password_and_stream_id(&password, &sid), - ); - Ok(this) - } - - fn send(&mut self, stream: XMPPStream, handshake: Handshake) { - let nonza = handshake; - let send = stream.send_stanza(nonza); - - self.state = ComponentAuthState::WaitSend(send); - } -} - -impl Future for ComponentAuth { - type Item = XMPPStream; - type Error = Error; - - fn poll(&mut self) -> Poll { - let state = replace(&mut self.state, ComponentAuthState::Invalid); - - match state { - ComponentAuthState::WaitSend(mut send) => match send.poll() { - Ok(Async::Ready(stream)) => { - self.state = ComponentAuthState::WaitRecv(stream); - self.poll() - } - Ok(Async::NotReady) => { - self.state = ComponentAuthState::WaitSend(send); - Ok(Async::NotReady) - } - Err(e) => Err(e)?, - }, - ComponentAuthState::WaitRecv(mut stream) => match stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) - if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) => - { - self.state = ComponentAuthState::Invalid; - Ok(Async::Ready(stream)) - } - Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) - if stanza.is("error", "http://etherx.jabber.org/streams") => - { - Err(AuthError::ComponentFail.into()) - } - Ok(Async::Ready(_event)) => { - // println!("ComponentAuth ignore {:?}", _event); - Ok(Async::NotReady) - } - Ok(_) => { - self.state = ComponentAuthState::WaitRecv(stream); - Ok(Async::NotReady) - } - Err(e) => Err(e)?, - }, - ComponentAuthState::Invalid => unreachable!(), +pub async fn auth( + stream: &mut XMPPStream, + password: String, +) -> Result<(), Error> { + let nonza = Handshake::from_password_and_stream_id(&password, &stream.id); + stream.send_stanza(nonza).await?; + + loop { + match stream.next().await { + Some(Ok(Packet::Stanza(ref stanza))) + if stanza.is("handshake", NS_JABBER_COMPONENT_ACCEPT) => + { + return Ok(()); + } + Some(Ok(Packet::Stanza(ref stanza))) + if stanza.is("error", "http://etherx.jabber.org/streams") => + { + return Err(AuthError::ComponentFail.into()); + } + Some(_) => {} + None => return Err(Error::Disconnected), } } } diff --git a/tokio-xmpp/src/component/mod.rs b/tokio-xmpp/src/component/mod.rs index a111aacafaf48d9a30d0ed67e93ca7706073a59c..35a2267fce6ac02be3514f46051da11032e32e77 100644 --- a/tokio-xmpp/src/component/mod.rs +++ b/tokio-xmpp/src/component/mod.rs @@ -1,163 +1,115 @@ //! Components in XMPP are services/gateways that are logged into an //! XMPP server under a JID consisting of just a domain name. They are //! allowed to use any user and resource identifiers in their stanzas. -use futures::{done, Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; -use std::mem::replace; +use futures::{sink::SinkExt, task::Poll, Sink, Stream}; +use std::pin::Pin; use std::str::FromStr; +use std::task::Context; use tokio::net::TcpStream; -use tokio_io::{AsyncRead, AsyncWrite}; -use xmpp_parsers::{Element, Jid, JidParseError}; +use xmpp_parsers::{Element, Jid}; -use super::event::Event; -use super::happy_eyeballs::Connecter; +use super::happy_eyeballs::connect; use super::xmpp_codec::Packet; use super::xmpp_stream; use super::Error; mod auth; -use self::auth::ComponentAuth; /// Component connection to an XMPP server +/// +/// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element` +/// (stanzas). Connection handling however is up to the user. pub struct Component { /// The component's Jabber-Id pub jid: Jid, - state: ComponentState, + stream: XMPPStream, } type XMPPStream = xmpp_stream::XMPPStream; const NS_JABBER_COMPONENT_ACCEPT: &str = "jabber:component:accept"; -enum ComponentState { - Invalid, - Disconnected, - Connecting(Box>), - Connected(XMPPStream), -} - impl Component { /// Start a new XMPP component - /// - /// Start polling the returned instance so that it will connect - /// and yield events. - pub fn new(jid: &str, password: &str, server: &str, port: u16) -> Result { + pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result { let jid = Jid::from_str(jid)?; let password = password.to_owned(); - let connect = Self::make_connect(jid.clone(), password, server, port); - Ok(Component { - jid, - state: ComponentState::Connecting(Box::new(connect)), - }) + let stream = Self::connect(jid.clone(), password, server, port).await?; + Ok(Component { jid, stream }) } - fn make_connect( + async fn connect( jid: Jid, password: String, server: &str, port: u16, - ) -> impl Future { - let jid1 = jid.clone(); + ) -> Result { let password = password; - done(Connecter::from_lookup(server, None, port)) - .flatten() - .and_then(move |tcp_stream| { - xmpp_stream::XMPPStream::start( - tcp_stream, - jid1, - NS_JABBER_COMPONENT_ACCEPT.to_owned(), - ) - }) - .and_then(move |xmpp_stream| Self::auth(xmpp_stream, password).expect("auth")) + let tcp_stream = connect(server, None, port).await?; + let mut xmpp_stream = + xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_COMPONENT_ACCEPT.to_owned()) + .await?; + auth::auth(&mut xmpp_stream, password).await?; + Ok(xmpp_stream) } - fn auth( - stream: xmpp_stream::XMPPStream, - password: String, - ) -> Result, Error> { - ComponentAuth::new(stream, password) + /// Send stanza + pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> { + self.send(stanza).await + } + + /// End connection + pub async fn send_end(&mut self) -> Result<(), Error> { + self.close().await } } impl Stream for Component { - type Item = Event; - type Error = Error; - - fn poll(&mut self) -> Poll, Self::Error> { - let state = replace(&mut self.state, ComponentState::Invalid); + type Item = Element; - match state { - ComponentState::Invalid => Err(Error::InvalidState), - ComponentState::Disconnected => Ok(Async::Ready(None)), - ComponentState::Connecting(mut connect) => match connect.poll() { - Ok(Async::Ready(stream)) => { - self.state = ComponentState::Connected(stream); - Ok(Async::Ready(Some(Event::Online(self.jid.clone())))) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => return Poll::Ready(Some(stanza)), + Poll::Ready(Some(Ok(Packet::Text(_)))) => { + // retry } - Ok(Async::NotReady) => { - self.state = ComponentState::Connecting(connect); - Ok(Async::NotReady) - } - Err(e) => Err(e), - }, - ComponentState::Connected(mut stream) => { - // Poll sink - match stream.poll_complete() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(())) => (), - Err(e) => return Err(e)?, - }; - - // Poll stream - match stream.poll() { - Ok(Async::NotReady) => { - self.state = ComponentState::Connected(stream); - Ok(Async::NotReady) - } - Ok(Async::Ready(None)) => { - // EOF - self.state = ComponentState::Disconnected; - Ok(Async::Ready(Some(Event::Disconnected))) - } - Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => { - self.state = ComponentState::Connected(stream); - Ok(Async::Ready(Some(Event::Stanza(stanza)))) - } - Ok(Async::Ready(_)) => { - self.state = ComponentState::Connected(stream); - Ok(Async::NotReady) - } - Err(e) => Err(e)?, + Poll::Ready(Some(Ok(_))) => + // unexpected + { + return Poll::Ready(None) } + Poll::Ready(Some(Err(_))) => return Poll::Ready(None), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, } } } } -impl Sink for Component { - type SinkItem = Element; - type SinkError = Error; +impl Sink for Component { + type Error = Error; - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - match self.state { - ComponentState::Connected(ref mut stream) => match stream - .start_send(Packet::Stanza(item)) - { - Ok(AsyncSink::NotReady(Packet::Stanza(stanza))) => Ok(AsyncSink::NotReady(stanza)), - Ok(AsyncSink::NotReady(_)) => { - panic!("Component.start_send with stanza but got something else back") - } - Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), - Err(e) => Err(e)?, - }, - _ => Ok(AsyncSink::NotReady(item)), - } + fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> { + Pin::new(&mut self.stream) + .start_send(Packet::Stanza(item)) + .map_err(|e| e.into()) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - match &mut self.state { - &mut ComponentState::Connected(ref mut stream) => { - stream.poll_complete().map_err(|e| e.into()) - } - _ => Ok(Async::Ready(())), - } + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream) + .poll_ready(cx) + .map_err(|e| e.into()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream) + .poll_flush(cx) + .map_err(|e| e.into()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream) + .poll_close(cx) + .map_err(|e| e.into()) } } diff --git a/tokio-xmpp/src/error.rs b/tokio-xmpp/src/error.rs index 498244425fe96e66945ac7d4163faf1fb791ac83..efde29ecddfbcdd5bdcb748d7a113dde2096874f 100644 --- a/tokio-xmpp/src/error.rs +++ b/tokio-xmpp/src/error.rs @@ -8,7 +8,7 @@ use trust_dns_proto::error::ProtoError; use trust_dns_resolver::error::ResolveError; use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition; -use xmpp_parsers::Error as ParsersError; +use xmpp_parsers::{Error as ParsersError, JidParseError}; /// Top-level error type #[derive(Debug)] @@ -20,6 +20,8 @@ pub enum Error { /// DNS label conversion error, no details available from module /// `idna` Idna, + /// Error parsing Jabber-Id + JidParse(JidParseError), /// Protocol-level error Protocol(ProtocolError), /// Authentication error @@ -38,6 +40,7 @@ impl fmt::Display for Error { Error::Io(e) => write!(fmt, "IO error: {}", e), Error::Connection(e) => write!(fmt, "connection error: {}", e), Error::Idna => write!(fmt, "IDNA error"), + Error::JidParse(e) => write!(fmt, "jid parse error: {}", e), Error::Protocol(e) => write!(fmt, "protocol error: {}", e), Error::Auth(e) => write!(fmt, "authentication error: {}", e), Error::Tls(e) => write!(fmt, "TLS error: {}", e), @@ -59,6 +62,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: JidParseError) -> Self { + Error::JidParse(e) + } +} + impl From for Error { fn from(e: ProtocolError) -> Self { Error::Protocol(e) diff --git a/tokio-xmpp/src/event.rs b/tokio-xmpp/src/event.rs index bd3dc402a2c2ad2f6b9eee29414e034db14406a2..db9f4059c4e48fbda8c1e262f09f93b105ec2ae2 100644 --- a/tokio-xmpp/src/event.rs +++ b/tokio-xmpp/src/event.rs @@ -1,3 +1,4 @@ +use super::Error; use xmpp_parsers::{Element, Jid}; /// High-level event on the Stream implemented by Client and Component @@ -6,7 +7,7 @@ pub enum Event { /// Stream is connected and initialized Online(Jid), /// Stream end - Disconnected, + Disconnected(Error), /// Received stanza/nonza Stanza(Element), } diff --git a/tokio-xmpp/src/happy_eyeballs.rs b/tokio-xmpp/src/happy_eyeballs.rs index 39394a51c55de106f4789016d20533de473dc7b1..eae489c2530c41433247f4563cc1ee70df004acb 100644 --- a/tokio-xmpp/src/happy_eyeballs.rs +++ b/tokio-xmpp/src/happy_eyeballs.rs @@ -1,195 +1,63 @@ use crate::{ConnecterError, Error}; -use futures::{Async, Future, Poll}; -use std::cell::RefCell; -use std::collections::BTreeMap; -use std::collections::VecDeque; -use std::io::Error as IoError; -use std::mem; use std::net::SocketAddr; -use tokio::net::tcp::ConnectFuture; use tokio::net::TcpStream; -use trust_dns_resolver::config::LookupIpStrategy; -use trust_dns_resolver::lookup::SrvLookupFuture; -use trust_dns_resolver::lookup_ip::LookupIpFuture; -use trust_dns_resolver::{AsyncResolver, Background, BackgroundLookup, IntoName, Name}; - -enum State { - ResolveSrv(AsyncResolver, BackgroundLookup), - ResolveTarget(AsyncResolver, Background, u16), - Connecting(Option, Vec>), - Invalid, +use trust_dns_resolver::{IntoName, TokioAsyncResolver}; + +async fn connect_to_host( + resolver: &TokioAsyncResolver, + host: &str, + port: u16, +) -> Result { + let ips = resolver + .lookup_ip(host) + .await + .map_err(ConnecterError::Resolve)?; + for ip in ips.iter() { + match TcpStream::connect(&SocketAddr::new(ip, port)).await { + Ok(stream) => return Ok(stream), + Err(_) => {} + } + } + Err(Error::Disconnected) } -pub struct Connecter { +pub async fn connect( + domain: &str, + srv: Option<&str>, fallback_port: u16, - srv_domain: Option, - domain: Name, - state: State, - targets: VecDeque<(Name, u16)>, - error: Option, -} - -fn resolver() -> Result { - let (config, mut opts) = trust_dns_resolver::system_conf::read_system_conf()?; - opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6; - let (resolver, resolver_background) = AsyncResolver::new(config, opts); - tokio::runtime::current_thread::spawn(resolver_background); - Ok(resolver) -} - -impl Connecter { - pub fn from_lookup( - domain: &str, - srv: Option<&str>, - fallback_port: u16, - ) -> Result { - if let Ok(ip) = domain.parse() { - // use specified IP address, not domain name, skip the whole dns part - let connect = RefCell::new(TcpStream::connect(&SocketAddr::new(ip, fallback_port))); - return Ok(Connecter { - fallback_port, - srv_domain: None, - domain: "nohost".into_name().map_err(ConnecterError::Dns)?, - state: State::Connecting(None, vec![connect]), - targets: VecDeque::new(), - error: None, - }); - } - - let srv_domain = match srv { - Some(srv) => Some( - format!("{}.{}.", srv, domain) - .into_name() - .map_err(ConnecterError::Dns)?, - ), - None => None, - }; - - let mut self_ = Connecter { - fallback_port, - srv_domain, - domain: domain.into_name().map_err(ConnecterError::Dns)?, - state: State::Invalid, - targets: VecDeque::new(), - error: None, - }; - - let resolver = resolver()?; - // Initialize state - match &self_.srv_domain { - &Some(ref srv_domain) => { - let srv_lookup = resolver.lookup_srv(srv_domain.clone()); - self_.state = State::ResolveSrv(resolver, srv_lookup); - } - None => { - self_.targets = [(self_.domain.clone(), self_.fallback_port)] - .iter() - .cloned() - .collect(); - self_.state = State::Connecting(Some(resolver), vec![]); - } - } - - Ok(self_) +) -> Result { + if let Ok(ip) = domain.parse() { + return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?); } -} -impl Future for Connecter { - type Item = TcpStream; - type Error = Error; + let resolver = TokioAsyncResolver::tokio_from_system_conf() + .await + .map_err(ConnecterError::Resolve)?; - fn poll(&mut self) -> Poll { - let state = mem::replace(&mut self.state, State::Invalid); - match state { - State::ResolveSrv(resolver, mut srv_lookup) => { - match srv_lookup.poll() { - Ok(Async::NotReady) => { - self.state = State::ResolveSrv(resolver, srv_lookup); - Ok(Async::NotReady) - } - Ok(Async::Ready(srv_result)) => { - let srv_map: BTreeMap<_, _> = srv_result - .iter() - .map(|srv| (srv.priority(), (srv.target().clone(), srv.port()))) - .collect(); - let targets = srv_map.into_iter().map(|(_, tp)| tp).collect(); - self.targets = targets; - self.state = State::Connecting(Some(resolver), vec![]); - self.poll() - } - Err(_) => { - // ignore, fallback - self.targets = [(self.domain.clone(), self.fallback_port)] - .iter() - .cloned() - .collect(); - self.state = State::Connecting(Some(resolver), vec![]); - self.poll() - } - } - } - State::Connecting(resolver, mut connects) => { - if resolver.is_some() && connects.len() == 0 && self.targets.len() > 0 { - let resolver = resolver.unwrap(); - let (host, port) = self.targets.pop_front().unwrap(); - let ip_lookup = resolver.lookup_ip(host); - self.state = State::ResolveTarget(resolver, ip_lookup, port); - self.poll() - } else if connects.len() > 0 { - let mut success = None; - connects.retain(|connect| match connect.borrow_mut().poll() { - Ok(Async::NotReady) => true, - Ok(Async::Ready(connection)) => { - success = Some(connection); - false - } - Err(e) => { - if self.error.is_none() { - self.error = Some(e.into()); - } - false - } - }); - match success { - Some(connection) => Ok(Async::Ready(connection)), - None => { - self.state = State::Connecting(resolver, connects); - Ok(Async::NotReady) - } - } - } else { - // All targets tried - match self.error.take() { - None => Err(ConnecterError::AllFailed.into()), - Some(e) => Err(e), - } - } - } - State::ResolveTarget(resolver, mut ip_lookup, port) => { - match ip_lookup.poll() { - Ok(Async::NotReady) => { - self.state = State::ResolveTarget(resolver, ip_lookup, port); - Ok(Async::NotReady) - } - Ok(Async::Ready(ip_result)) => { - let connects = ip_result - .iter() - .map(|ip| RefCell::new(TcpStream::connect(&SocketAddr::new(ip, port)))) - .collect(); - self.state = State::Connecting(Some(resolver), connects); - self.poll() - } - Err(e) => { - if self.error.is_none() { - self.error = Some(ConnecterError::Resolve(e).into()); - } - // ignore, next… - self.state = State::Connecting(Some(resolver), vec![]); - self.poll() - } + let srv_records = match srv { + Some(srv) => { + let srv_domain = format!("{}.{}.", srv, domain) + .into_name() + .map_err(ConnecterError::Dns)?; + resolver.srv_lookup(srv_domain).await.ok() + } + None => None, + }; + + match srv_records { + Some(lookup) => { + // TODO: sort lookup records by priority/weight + for srv in lookup.iter() { + match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await { + Ok(stream) => return Ok(stream), + Err(_) => {} } } - _ => panic!(""), + Err(Error::Disconnected) + } + None => { + // SRV lookup error, retry with hostname + connect_to_host(&resolver, domain, fallback_port).await } } } diff --git a/tokio-xmpp/src/lib.rs b/tokio-xmpp/src/lib.rs index 2bd823eef822a644becfa69b81c3d3b45782317f..f14ff6b22e379152f02397e0aa87e5577d53569d 100644 --- a/tokio-xmpp/src/lib.rs +++ b/tokio-xmpp/src/lib.rs @@ -6,10 +6,9 @@ mod starttls; mod stream_start; pub mod xmpp_codec; pub use crate::xmpp_codec::Packet; -pub mod xmpp_stream; -pub use crate::starttls::StartTlsClient; mod event; mod happy_eyeballs; +pub mod xmpp_stream; pub use crate::event::Event; mod client; pub use crate::client::Client; diff --git a/tokio-xmpp/src/starttls.rs b/tokio-xmpp/src/starttls.rs index 6c5ba49294276534b639f090b4e8d9366ea9dbd3..de19eb044dbf47a54a3b805623281efd86af2311 100644 --- a/tokio-xmpp/src/starttls.rs +++ b/tokio-xmpp/src/starttls.rs @@ -1,114 +1,39 @@ -use futures::sink; -use futures::stream::Stream; -use futures::{Async, Future, Poll, Sink}; +use futures::{sink::SinkExt, stream::StreamExt}; use native_tls::TlsConnector as NativeTlsConnector; -use std::mem::replace; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_tls::{Connect, TlsConnector, TlsStream}; -use xmpp_parsers::{Element, Jid}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_tls::{TlsConnector, TlsStream}; +use xmpp_parsers::Element; use crate::xmpp_codec::Packet; use crate::xmpp_stream::XMPPStream; -use crate::Error; +use crate::{Error, ProtocolError}; /// XMPP TLS XML namespace pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; -/// XMPP stream that switches to TLS if available in received features -pub struct StartTlsClient { - state: StartTlsClientState, - jid: Jid, -} - -enum StartTlsClientState { - Invalid, - SendStartTls(sink::Send>), - AwaitProceed(XMPPStream), - StartingTls(Connect), -} - -impl StartTlsClient { - /// Waits for - pub fn from_stream(xmpp_stream: XMPPStream) -> Self { - let jid = xmpp_stream.jid.clone(); - - let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build(); - let packet = Packet::Stanza(nonza); - let send = xmpp_stream.send(packet); - - StartTlsClient { - state: StartTlsClientState::SendStartTls(send), - jid, +pub async fn starttls( + mut xmpp_stream: XMPPStream, +) -> Result, Error> { + let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build(); + let packet = Packet::Stanza(nonza); + xmpp_stream.send(packet).await?; + + loop { + match xmpp_stream.next().await { + Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break, + Some(Ok(Packet::Text(_))) => {} + Some(Err(e)) => return Err(e.into()), + _ => { + return Err(ProtocolError::NoTls.into()); + } } } -} - -impl Future for StartTlsClient { - type Item = TlsStream; - type Error = Error; - - fn poll(&mut self) -> Poll { - let old_state = replace(&mut self.state, StartTlsClientState::Invalid); - let mut retry = false; - let (new_state, result) = match old_state { - StartTlsClientState::SendStartTls(mut send) => match send.poll() { - Ok(Async::Ready(xmpp_stream)) => { - let new_state = StartTlsClientState::AwaitProceed(xmpp_stream); - retry = true; - (new_state, Ok(Async::NotReady)) - } - Ok(Async::NotReady) => { - (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)) - } - Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())), - }, - StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(ref stanza)))) - if stanza.name() == "proceed" => - { - let stream = xmpp_stream.stream.into_inner(); - let connect = - TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) - .connect(&self.jid.clone().domain(), stream); - let new_state = StartTlsClientState::StartingTls(connect); - retry = true; - (new_state, Ok(Async::NotReady)) - } - Ok(Async::Ready(_value)) => { - // println!("StartTlsClient ignore {:?}", _value); - ( - StartTlsClientState::AwaitProceed(xmpp_stream), - Ok(Async::NotReady), - ) - } - Ok(_) => ( - StartTlsClientState::AwaitProceed(xmpp_stream), - Ok(Async::NotReady), - ), - Err(e) => ( - StartTlsClientState::AwaitProceed(xmpp_stream), - Err(Error::Protocol(e.into())), - ), - }, - StartTlsClientState::StartingTls(mut connect) => match connect.poll() { - Ok(Async::Ready(tls_stream)) => { - (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))) - } - Ok(Async::NotReady) => ( - StartTlsClientState::StartingTls(connect), - Ok(Async::NotReady), - ), - Err(e) => (StartTlsClientState::Invalid, Err(e.into())), - }, - StartTlsClientState::Invalid => unreachable!(), - }; + let domain = xmpp_stream.jid.clone().domain(); + let stream = xmpp_stream.into_inner(); + let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) + .connect(&domain, stream) + .await?; - self.state = new_state; - if retry { - self.poll() - } else { - result - } - } + Ok(tls_stream) } diff --git a/tokio-xmpp/src/stream_start.rs b/tokio-xmpp/src/stream_start.rs index 317a9f1ee3f80af01c0cd5d70e961ce2994d8451..1fe16845c66e0faca79c882f3994f3bd1e9c7596 100644 --- a/tokio-xmpp/src/stream_start.rs +++ b/tokio-xmpp/src/stream_start.rs @@ -1,7 +1,7 @@ -use futures::{sink, Async, Future, Poll, Sink, Stream}; -use std::mem::replace; -use tokio_codec::Framed; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::{sink::SinkExt, stream::StreamExt}; +use std::marker::Unpin; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; use xmpp_parsers::{Element, Jid}; use crate::xmpp_codec::{Packet, XMPPCodec}; @@ -10,116 +10,66 @@ use crate::{Error, ProtocolError}; const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; -pub struct StreamStart { - state: StreamStartState, +pub async fn start( + mut stream: Framed, jid: Jid, ns: String, -} - -enum StreamStartState { - SendStart(sink::Send>), - RecvStart(Framed), - RecvFeatures(Framed, String), - Invalid, -} - -impl StreamStart { - pub fn from_stream(stream: Framed, jid: Jid, ns: String) -> Self { - let attrs = [ - ("to".to_owned(), jid.clone().domain()), - ("version".to_owned(), "1.0".to_owned()), - ("xmlns".to_owned(), ns.clone()), - ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()), - ] - .iter() - .cloned() - .collect(); - let send = stream.send(Packet::StreamStart(attrs)); +) -> Result, Error> { + let attrs = [ + ("to".to_owned(), jid.clone().domain()), + ("version".to_owned(), "1.0".to_owned()), + ("xmlns".to_owned(), ns.clone()), + ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()), + ] + .iter() + .cloned() + .collect(); + stream.send(Packet::StreamStart(attrs)).await?; - StreamStart { - state: StreamStartState::SendStart(send), - jid, - ns, + let stream_attrs; + loop { + match stream.next().await { + Some(Ok(Packet::StreamStart(attrs))) => { + stream_attrs = attrs; + break; + } + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e.into()), + None => return Err(Error::Disconnected), } } -} - -impl Future for StreamStart { - type Item = XMPPStream; - type Error = Error; - fn poll(&mut self) -> Poll { - let old_state = replace(&mut self.state, StreamStartState::Invalid); - let mut retry = false; - - let (new_state, result) = match old_state { - StreamStartState::SendStart(mut send) => match send.poll() { - Ok(Async::Ready(stream)) => { - retry = true; - (StreamStartState::RecvStart(stream), Ok(Async::NotReady)) - } - Ok(Async::NotReady) => (StreamStartState::SendStart(send), Ok(Async::NotReady)), - Err(e) => (StreamStartState::Invalid, Err(e.into())), - }, - StreamStartState::RecvStart(mut stream) => match stream.poll() { - Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => { - let stream_ns = stream_attrs - .get("xmlns") - .ok_or(ProtocolError::NoStreamNamespace)? - .clone(); - if self.ns == "jabber:client" { - retry = true; - // TODO: skip RecvFeatures for version < 1.0 - ( - StreamStartState::RecvFeatures(stream, stream_ns), - Ok(Async::NotReady), - ) - } else { - let id = stream_attrs - .get("id") - .ok_or(ProtocolError::NoStreamId)? - .clone(); - // FIXME: huge hack, shouldn’t be an element! - let stream = XMPPStream::new( - self.jid.clone(), - stream, - self.ns.clone(), - Element::builder(id).build(), - ); - (StreamStartState::Invalid, Ok(Async::Ready(stream))) - } + let stream_ns = stream_attrs + .get("xmlns") + .ok_or(ProtocolError::NoStreamNamespace)? + .clone(); + let stream_id = stream_attrs + .get("id") + .ok_or(ProtocolError::NoStreamId)? + .clone(); + let stream = if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() { + let stream_features; + loop { + match stream.next().await { + Some(Ok(Packet::Stanza(stanza))) if stanza.is("features", NS_XMPP_STREAM) => { + stream_features = stanza; + break; } - Ok(Async::Ready(_)) => return Err(ProtocolError::InvalidToken.into()), - Ok(Async::NotReady) => (StreamStartState::RecvStart(stream), Ok(Async::NotReady)), - Err(e) => return Err(ProtocolError::from(e).into()), - }, - StreamStartState::RecvFeatures(mut stream, stream_ns) => match stream.poll() { - Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => { - if stanza.is("features", NS_XMPP_STREAM) { - let stream = - XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza); - (StreamStartState::Invalid, Ok(Async::Ready(stream))) - } else { - ( - StreamStartState::RecvFeatures(stream, stream_ns), - Ok(Async::NotReady), - ) - } - } - Ok(Async::Ready(_)) | Ok(Async::NotReady) => ( - StreamStartState::RecvFeatures(stream, stream_ns), - Ok(Async::NotReady), - ), - Err(e) => return Err(ProtocolError::from(e).into()), - }, - StreamStartState::Invalid => unreachable!(), - }; - - self.state = new_state; - if retry { - self.poll() - } else { - result + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e.into()), + None => return Err(Error::Disconnected), + } } - } + XMPPStream::new(jid, stream, ns, stream_id, stream_features) + } else { + // FIXME: huge hack, shouldn’t be an element! + XMPPStream::new( + jid, + stream, + ns, + stream_id.clone(), + Element::builder(stream_id).build(), + ) + }; + Ok(stream) } diff --git a/tokio-xmpp/src/xmpp_codec.rs b/tokio-xmpp/src/xmpp_codec.rs index b1f6c42a0c70dd050ca3d56b08ce23c996c82711..eaf927d34b299aedc85abbb60607b6408a06749e 100644 --- a/tokio-xmpp/src/xmpp_codec.rs +++ b/tokio-xmpp/src/xmpp_codec.rs @@ -5,16 +5,16 @@ use bytes::{BufMut, BytesMut}; use log::debug; use std; use std::borrow::Cow; -use std::cell::RefCell; use std::collections::vec_deque::VecDeque; use std::collections::HashMap; use std::default::Default; use std::fmt::Write; use std::io; use std::iter::FromIterator; -use std::rc::Rc; use std::str::from_utf8; -use tokio_codec::{Decoder, Encoder}; +use std::sync::Arc; +use std::sync::Mutex; +use tokio_util::codec::{Decoder, Encoder}; use xml5ever::buffer_queue::BufferQueue; use xml5ever::interface::Attribute; use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer}; @@ -38,14 +38,14 @@ type QueueItem = Result; /// Parser state struct ParserSink { // Ready stanzas, shared with XMPPCodec - queue: Rc>>, + queue: Arc>>, // Parsing stack stack: Vec, ns_stack: Vec, String>>, } impl ParserSink { - pub fn new(queue: Rc>>) -> Self { + pub fn new(queue: Arc>>) -> Self { ParserSink { queue, stack: vec![], @@ -54,11 +54,11 @@ impl ParserSink { } fn push_queue(&self, pkt: Packet) { - self.queue.borrow_mut().push_back(Ok(pkt)); + self.queue.lock().unwrap().push_back(Ok(pkt)); } fn push_queue_error(&self, e: ParserError) { - self.queue.borrow_mut().push_back(Err(e)); + self.queue.lock().unwrap().push_back(Err(e)); } /// Lookup XML namespace declaration for given prefix (or no prefix) @@ -169,7 +169,6 @@ impl TokenSink for ParserSink { }, Token::EOFToken => self.push_queue(Packet::StreamEnd), Token::ParseError(s) => { - // println!("ParseError: {:?}", s); self.push_queue_error(ParserError::Parse(ParseError(s))); } _ => (), @@ -190,13 +189,13 @@ pub struct XMPPCodec { // TODO: optimize using tendrils? buf: Vec, /// Shared with ParserSink - queue: Rc>>, + queue: Arc>>, } impl XMPPCodec { /// Constructor pub fn new() -> Self { - let queue = Rc::new(RefCell::new(VecDeque::new())); + let queue = Arc::new(Mutex::new(VecDeque::new())); let sink = ParserSink::new(queue.clone()); // TODO: configure parser? let parser = XmlTokenizer::new(sink, Default::default()); @@ -222,10 +221,10 @@ impl Decoder for XMPPCodec { fn decode(&mut self, buf: &mut BytesMut) -> Result, Self::Error> { let buf1: Box> = if !self.buf.is_empty() && !buf.is_empty() { let mut prefix = std::mem::replace(&mut self.buf, vec![]); - prefix.extend_from_slice(buf.take().as_ref()); + prefix.extend_from_slice(&buf.split_to(buf.len())); Box::new(prefix) } else { - Box::new(buf.take()) + Box::new(buf.split_to(buf.len())) }; let buf1 = buf1.as_ref().as_ref(); match from_utf8(buf1) { @@ -258,7 +257,7 @@ impl Decoder for XMPPCodec { } } - match self.queue.borrow_mut().pop_front() { + match self.queue.lock().unwrap().pop_front() { None => Ok(None), Some(result) => result.map(|pkt| Some(pkt)), } @@ -372,7 +371,7 @@ mod tests { fn test_stream_start() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, @@ -384,14 +383,14 @@ mod tests { fn test_stream_end() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, _ => false, }); b.clear(); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamEnd)) => true, @@ -403,7 +402,7 @@ mod tests { fn test_truncated_stanza() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, @@ -411,7 +410,7 @@ mod tests { }); b.clear(); - b.put(r"ßß true, @@ -419,7 +418,7 @@ mod tests { }); b.clear(); - b.put(r">"); + b.put_slice(b">"); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true, @@ -431,7 +430,7 @@ mod tests { fn test_truncated_utf8() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, @@ -460,7 +459,7 @@ mod tests { fn test_atrribute_prefix() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, @@ -468,7 +467,7 @@ mod tests { }); b.clear(); - b.put(r"Test status"); + b.put_slice(b"Test status"); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::Stanza(ref el))) @@ -483,10 +482,10 @@ mod tests { /// By default, encode() only get's a BytesMut that has 8kb space reserved. #[test] fn test_large_stanza() { - use futures::{Future, Sink}; + use futures::{executor::block_on, sink::SinkExt}; use std::io::Cursor; - use tokio_codec::FramedWrite; - let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new()); + use tokio_util::codec::FramedWrite; + let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new()); let mut text = "".to_owned(); for _ in 0..2usize.pow(15) { text = text + "A"; @@ -494,7 +493,7 @@ mod tests { let stanza = Element::builder("message") .append(Element::builder("body").append(text.as_ref()).build()) .build(); - let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send"); + block_on(framed.send(Packet::Stanza(stanza))).expect("send"); assert_eq!( framed.get_ref().get_ref(), &("".to_owned() + &text + "").as_bytes() @@ -505,7 +504,7 @@ mod tests { fn test_cut_out_stanza() { let mut c = XMPPCodec::new(); let mut b = BytesMut::with_capacity(1024); - b.put(r""); + b.put_slice(b""); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::StreamStart(_))) => true, @@ -513,8 +512,8 @@ mod tests { }); b.clear(); - b.put(r"Foo"); + b.put_slice(b"Foo"); let r = c.decode(&mut b); assert!(match r { Ok(Some(Packet::Stanza(_))) => true, diff --git a/tokio-xmpp/src/xmpp_stream.rs b/tokio-xmpp/src/xmpp_stream.rs index 05e57156c3110b4b71981d4208592edb2991b4f8..209b7e005e94d7b4a6da1beeedea47319d240bd5 100644 --- a/tokio-xmpp/src/xmpp_stream.rs +++ b/tokio-xmpp/src/xmpp_stream.rs @@ -1,23 +1,28 @@ //! `XMPPStream` is the common container for all XMPP network connections use futures::sink::Send; -use futures::{Poll, Sink, StartSend, Stream}; -use tokio_codec::Framed; -use tokio_io::{AsyncRead, AsyncWrite}; +use futures::{sink::SinkExt, task::Poll, Sink, Stream}; +use std::ops::DerefMut; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; use xmpp_parsers::{Element, Jid}; -use crate::stream_start::StreamStart; +use crate::stream_start; use crate::xmpp_codec::{Packet, XMPPCodec}; +use crate::Error; /// namespace pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams"; /// Wraps a `stream` -pub struct XMPPStream { +pub struct XMPPStream { /// The local Jabber-Id pub jid: Jid, /// Codec instance - pub stream: Framed, + pub stream: Mutex>, /// `` for XMPP version 1.0 pub stream_features: Element, /// Root namespace @@ -25,68 +30,94 @@ pub struct XMPPStream { /// This is different for either c2s, s2s, or component /// connections. pub ns: String, + /// Stream `id` attribute + pub id: String, } -impl XMPPStream { +// // TODO: fix this hack +// unsafe impl core::marker::Send for XMPPStream {} +// unsafe impl Sync for XMPPStream {} + +impl XMPPStream { /// Constructor pub fn new( jid: Jid, stream: Framed, ns: String, + id: String, stream_features: Element, ) -> Self { XMPPStream { jid, - stream, + stream: Mutex::new(stream), stream_features, ns, + id, } } /// Send a `` start tag - pub fn start(stream: S, jid: Jid, ns: String) -> StreamStart { + pub async fn start<'a>(stream: S, jid: Jid, ns: String) -> Result { let xmpp_stream = Framed::new(stream, XMPPCodec::new()); - StreamStart::from_stream(xmpp_stream, jid, ns) + stream_start::start(xmpp_stream, jid, ns).await } /// Unwraps the inner stream + // TODO: use this everywhere pub fn into_inner(self) -> S { - self.stream.into_inner() + self.stream.into_inner().unwrap().into_inner() } /// Re-run `start()` - pub fn restart(self) -> StreamStart { - Self::start(self.stream.into_inner(), self.jid, self.ns) + pub async fn restart<'a>(self) -> Result { + let stream = self.stream.into_inner().unwrap().into_inner(); + Self::start(stream, self.jid, self.ns).await } } -impl XMPPStream { +impl XMPPStream { /// Convenience method - pub fn send_stanza>(self, e: E) -> Send { + pub fn send_stanza>(&mut self, e: E) -> Send { self.send(Packet::Stanza(e.into())) } } /// Proxy to self.stream -impl Sink for XMPPStream { - type SinkItem = as Sink>::SinkItem; - type SinkError = as Sink>::SinkError; +impl Sink for XMPPStream { + type Error = crate::Error; + + fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll> { + // Pin::new(&mut self.stream).poll_ready(ctx) + // .map_err(|e| e.into()) + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> { + Pin::new(&mut self.stream.lock().unwrap().deref_mut()) + .start_send(item) + .map_err(|e| e.into()) + } - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.stream.start_send(item) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream.lock().unwrap().deref_mut()) + .poll_flush(cx) + .map_err(|e| e.into()) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.stream.poll_complete() + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream.lock().unwrap().deref_mut()) + .poll_close(cx) + .map_err(|e| e.into()) } } /// Proxy to self.stream -impl Stream for XMPPStream { - type Item = as Stream>::Item; - type Error = as Stream>::Error; +impl Stream for XMPPStream { + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.stream.poll() + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream.lock().unwrap().deref_mut()) + .poll_next(cx) + .map(|result| result.map(|result| result.map_err(|e| e.into()))) } }