Detailed changes
@@ -8,7 +8,8 @@ use std::io::BufReader;
use std::fs::File;
use tokio_core::reactor::Core;
use futures::{Future, Stream};
-use tokio_xmpp::{Packet, TcpClient, StartTlsClient};
+use tokio_xmpp::TcpClient;
+use tokio_xmpp::xmpp_codec::Packet;
use rustls::ClientConfig;
fn main() {
@@ -26,8 +27,13 @@ fn main() {
let client = TcpClient::connect(
&addr,
&core.handle()
- ).and_then(|stream| StartTlsClient::from_stream(stream, arc_config)
).and_then(|stream| {
+ if stream.can_starttls() {
+ stream.starttls(arc_config)
+ } else {
+ panic!("No STARTTLS")
+ }
+ }).and_then(|stream| {
stream.for_each(|event| {
match event {
Packet::Stanza(el) => println!("<< {}", el),
@@ -8,8 +8,9 @@ extern crate rustls;
extern crate tokio_rustls;
-mod xmpp_codec;
-pub use xmpp_codec::*;
+pub mod xmpp_codec;
+pub mod xmpp_stream;
+mod stream_start;
mod tcp;
pub use tcp::*;
mod starttls;
@@ -1,5 +1,5 @@
use std::mem::replace;
-use std::io::{Error, ErrorKind};
+use std::io::Error;
use std::sync::Arc;
use futures::{Future, Sink, Poll, Async};
use futures::stream::Stream;
@@ -9,36 +9,44 @@ use rustls::*;
use tokio_rustls::*;
use xml;
-use super::{XMPPStream, XMPPCodec, Packet};
+use xmpp_codec::*;
+use xmpp_stream::*;
+use stream_start::StreamStart;
-const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
-const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
+pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
-pub struct StartTlsClient<S: AsyncWrite> {
+pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
state: StartTlsClientState<S>,
arc_config: Arc<ClientConfig>,
}
-enum StartTlsClientState<S: AsyncWrite> {
+enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
Invalid,
- AwaitFeatures(XMPPStream<S>),
SendStartTls(sink::Send<XMPPStream<S>>),
AwaitProceed(XMPPStream<S>),
StartingTls(ConnectAsync<S>),
+ Start(StreamStart<TlsStream<S, ClientSession>>),
}
-impl<S: AsyncWrite> StartTlsClient<S> {
+impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
/// Waits for <stream:features>
pub fn from_stream(xmpp_stream: XMPPStream<S>, arc_config: Arc<ClientConfig>) -> Self {
+ let nonza = xml::Element::new(
+ "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
+ vec![]
+ );
+ println!("send {}", nonza);
+ let packet = Packet::Stanza(nonza);
+ let send = xmpp_stream.send(packet);
+
StartTlsClient {
- state: StartTlsClientState::AwaitFeatures(xmpp_stream),
+ state: StartTlsClientState::SendStartTls(send),
arc_config: arc_config,
}
}
}
-// TODO: eval <stream:features>, check ns
impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
type Item = XMPPStream<TlsStream<S, ClientSession>>;
type Error = Error;
@@ -48,40 +56,6 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
let mut retry = false;
let (new_state, result) = match old_state {
- StartTlsClientState::AwaitFeatures(mut xmpp_stream) =>
- match xmpp_stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
- if stanza.name == "features"
- && stanza.ns == Some(NS_XMPP_STREAM.to_owned())
- =>
- {
- println!("Got features: {}", stanza);
- match stanza.get_child("starttls", Some(NS_XMPP_TLS)) {
- None =>
- (StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))),
- Some(_) => {
- let nonza = xml::Element::new(
- "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
- vec![]
- );
- println!("send {}", nonza);
- let packet = Packet::Stanza(nonza);
- let send = xmpp_stream.send(packet);
- let new_state = StartTlsClientState::SendStartTls(send);
- retry = true;
- (new_state, Ok(Async::NotReady))
- },
- }
- },
- Ok(Async::Ready(value)) => {
- println!("StartTlsClient ignore {:?}", value);
- (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
- },
- Ok(_) =>
- (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)),
- Err(e) =>
- (StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)),
- },
StartTlsClientState::SendStartTls(mut send) =>
match send.poll() {
Ok(Async::Ready(xmpp_stream)) => {
@@ -109,7 +83,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
},
Ok(Async::Ready(value)) => {
println!("StartTlsClient ignore {:?}", value);
- (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
+ (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
},
Ok(_) =>
(StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
@@ -120,14 +94,25 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
match connect.poll() {
Ok(Async::Ready(tls_stream)) => {
println!("Got a TLS stream!");
- let xmpp_stream = XMPPCodec::frame_stream(tls_stream);
- (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream)))
+ let start = XMPPStream::from_stream(tls_stream, "spaceboyz.net".to_owned());
+ let new_state = StartTlsClientState::Start(start);
+ retry = true;
+ (new_state, Ok(Async::NotReady))
},
Ok(Async::NotReady) =>
(StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
Err(e) =>
(StartTlsClientState::StartingTls(connect), Err(e)),
},
+ StartTlsClientState::Start(mut start) =>
+ match start.poll() {
+ Ok(Async::Ready(xmpp_stream)) =>
+ (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
+ Ok(Async::NotReady) =>
+ (StartTlsClientState::Start(start), Ok(Async::NotReady)),
+ Err(e) =>
+ (StartTlsClientState::Invalid, Err(e)),
+ },
StartTlsClientState::Invalid =>
unreachable!(),
};
@@ -0,0 +1,102 @@
+use std::mem::replace;
+use std::io::{Error, ErrorKind};
+use std::collections::HashMap;
+use futures::*;
+use tokio_io::{AsyncRead, AsyncWrite};
+use tokio_io::codec::Framed;
+
+use xmpp_codec::*;
+use xmpp_stream::*;
+
+const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
+
+pub struct StreamStart<S: AsyncWrite> {
+ state: StreamStartState<S>,
+}
+
+enum StreamStartState<S: AsyncWrite> {
+ SendStart(sink::Send<Framed<S, XMPPCodec>>),
+ RecvStart(Framed<S, XMPPCodec>),
+ RecvFeatures(Framed<S, XMPPCodec>, HashMap<String, String>),
+ Invalid,
+}
+
+impl<S: AsyncWrite> StreamStart<S> {
+ pub fn from_stream(stream: Framed<S, XMPPCodec>, to: String) -> Self {
+ let attrs = [("to".to_owned(), to),
+ ("version".to_owned(), "1.0".to_owned()),
+ ("xmlns".to_owned(), "jabber:client".to_owned()),
+ ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
+ ].iter().cloned().collect();
+ let send = stream.send(Packet::StreamStart(attrs));
+
+ StreamStart {
+ state: StreamStartState::SendStart(send),
+ }
+ }
+}
+
+impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
+ type Item = XMPPStream<S>;
+ type Error = Error;
+
+ fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+ let old_state = replace(&mut self.state, StreamStartState::Invalid);
+ let mut retry = false;
+
+ let (new_state, result) = match old_state {
+ StreamStartState::SendStart(mut send) =>
+ match send.poll() {
+ Ok(Async::Ready(stream)) => {
+ retry = true;
+ (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
+ },
+ Ok(Async::NotReady) =>
+ (StreamStartState::SendStart(send), Ok(Async::NotReady)),
+ Err(e) =>
+ (StreamStartState::Invalid, Err(e)),
+ },
+ StreamStartState::RecvStart(mut stream) =>
+ match stream.poll() {
+ Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
+ retry = true;
+ // TODO: skip RecvFeatures for version < 1.0
+ (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+ },
+ Ok(Async::Ready(_)) =>
+ return Err(Error::from(ErrorKind::InvalidData)),
+ Ok(Async::NotReady) =>
+ (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
+ Err(e) =>
+ return Err(e),
+ },
+ StreamStartState::RecvFeatures(mut stream, stream_attrs) =>
+ match stream.poll() {
+ Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
+ if stanza.name == "features"
+ && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) {
+ (StreamStartState::Invalid, Ok(Async::Ready(XMPPStream { stream, stream_attrs, stream_features: stanza })))
+ } else {
+ (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+ },
+ Ok(Async::Ready(item)) => {
+ println!("StreamStart skip {:?}", item);
+ (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+ },
+ Ok(Async::NotReady) =>
+ (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)),
+ Err(e) =>
+ return Err(e),
+ },
+ StreamStartState::Invalid =>
+ unreachable!(),
+ };
+
+ self.state = new_state;
+ if retry {
+ self.poll()
+ } else {
+ result
+ }
+ }
+}
@@ -1,40 +1,22 @@
-use std::fmt;
use std::net::SocketAddr;
-use std::io::{Error, ErrorKind};
-use futures::{Future, Sink, Poll, Async};
-use futures::stream::Stream;
-use futures::sink;
+use std::io::Error;
+use futures::{Future, Poll, Async};
use tokio_core::reactor::Handle;
use tokio_core::net::{TcpStream, TcpStreamNew};
-use super::{XMPPStream, XMPPCodec, Packet};
+use xmpp_stream::*;
+use stream_start::StreamStart;
-
-#[derive(Debug)]
pub struct TcpClient {
state: TcpClientState,
}
enum TcpClientState {
Connecting(TcpStreamNew),
- SendStart(sink::Send<XMPPStream<TcpStream>>),
- RecvStart(Option<XMPPStream<TcpStream>>),
+ Start(StreamStart<TcpStream>),
Established,
}
-impl fmt::Debug for TcpClientState {
- fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
- let s = match *self {
- TcpClientState::Connecting(_) => "Connecting",
- TcpClientState::SendStart(_) => "SendStart",
- TcpClientState::RecvStart(_) => "RecvStart",
- TcpClientState::Established => "Established",
- };
- try!(write!(fmt, "{}", s));
- Ok(())
- }
-}
-
impl TcpClient {
pub fn connect(addr: &SocketAddr, handle: &Handle) -> Self {
let tcp_stream_new = TcpStream::connect(addr, handle);
@@ -52,27 +34,12 @@ impl Future for TcpClient {
let (new_state, result) = match self.state {
TcpClientState::Connecting(ref mut tcp_stream_new) => {
let tcp_stream = try_ready!(tcp_stream_new.poll());
- let xmpp_stream = XMPPCodec::frame_stream(tcp_stream);
- let send = xmpp_stream.send(Packet::StreamStart);
- let new_state = TcpClientState::SendStart(send);
- (new_state, Ok(Async::NotReady))
- },
- TcpClientState::SendStart(ref mut send) => {
- let xmpp_stream = try_ready!(send.poll());
- let new_state = TcpClientState::RecvStart(Some(xmpp_stream));
+ let start = XMPPStream::from_stream(tcp_stream, "spaceboyz.net".to_owned());
+ let new_state = TcpClientState::Start(start);
(new_state, Ok(Async::NotReady))
},
- TcpClientState::RecvStart(ref mut opt_xmpp_stream) => {
- let mut xmpp_stream = opt_xmpp_stream.take().unwrap();
- match xmpp_stream.poll() {
- Ok(Async::Ready(Some(Packet::StreamStart))) => println!("Recv start!"),
- Ok(Async::Ready(_)) => return Err(Error::from(ErrorKind::InvalidData)),
- Ok(Async::NotReady) => {
- *opt_xmpp_stream = Some(xmpp_stream);
- return Ok(Async::NotReady);
- },
- Err(e) => return Err(e)
- };
+ TcpClientState::Start(ref mut start) => {
+ let xmpp_stream = try_ready!(start.poll());
let new_state = TcpClientState::Established;
(new_state, Ok(Async::Ready(xmpp_stream)))
},
@@ -80,7 +47,6 @@ impl Future for TcpClient {
unreachable!(),
};
- println!("Next state: {:?}", new_state);
self.state = new_state;
match result {
// by polling again, we register new future
@@ -3,18 +3,17 @@ use std::fmt::Write;
use std::str::from_utf8;
use std::io::{Error, ErrorKind};
use std::collections::HashMap;
-use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_io::codec::{Framed, Encoder, Decoder};
+use tokio_io::codec::{Encoder, Decoder};
use xml;
use bytes::*;
const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
-const NS_STREAMS: &'static str = "http://etherx.jabber.org/streams";
-const NS_CLIENT: &'static str = "jabber:client";
+
+pub type Attributes = HashMap<(String, Option<String>), String>;
struct XMPPRoot {
builder: xml::ElementBuilder,
- pub attributes: HashMap<(String, Option<String>), String>,
+ pub attributes: Attributes,
}
impl XMPPRoot {
@@ -49,13 +48,11 @@ impl XMPPRoot {
#[derive(Debug)]
pub enum Packet {
Error(Box<std::error::Error>),
- StreamStart,
+ StreamStart(HashMap<String, String>),
Stanza(xml::Element),
StreamEnd,
}
-pub type XMPPStream<T> = Framed<T, XMPPCodec>;
-
pub struct XMPPCodec {
parser: xml::Parser,
root: Option<XMPPRoot>,
@@ -68,12 +65,6 @@ impl XMPPCodec {
root: None,
}
}
-
- pub fn frame_stream<S>(stream: S) -> Framed<S, XMPPCodec>
- where S: AsyncRead + AsyncWrite
- {
- AsyncRead::framed(stream, XMPPCodec::new())
- }
}
impl Decoder for XMPPCodec {
@@ -97,8 +88,12 @@ impl Decoder for XMPPCodec {
// Expecting <stream:stream>
match event {
Ok(xml::Event::ElementStart(start_tag)) => {
+ let mut attrs: HashMap<String, String> = HashMap::new();
+ for (&(ref name, _), value) in &start_tag.attributes {
+ attrs.insert(name.to_owned(), value.to_owned());
+ }
+ result = Some(Packet::StreamStart(attrs));
self.root = Some(XMPPRoot::new(start_tag));
- result = Some(Packet::StreamStart);
break
},
Err(e) => {
@@ -146,18 +141,23 @@ impl Encoder for XMPPCodec {
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
- Packet::StreamStart => {
- write!(dst,
- "<?xml version='1.0'?>\n
-<stream:stream version='1.0' to='spaceboyz.net' xmlns='{}' xmlns:stream='{}'>\n",
- NS_CLIENT, NS_STREAMS)
- .map_err(|_| Error::from(ErrorKind::WriteZero))
+ Packet::StreamStart(start_attrs) => {
+ let mut buf = String::new();
+ write!(buf, "<stream:stream").unwrap();
+ for (ref name, ref value) in &start_attrs {
+ write!(buf, " {}=\"{}\"", xml::escape(&name), xml::escape(&value))
+ .unwrap();
+ }
+ write!(buf, ">\n").unwrap();
+
+ println!("Encode start to {}", buf);
+ write!(dst, "{}", buf)
},
Packet::Stanza(stanza) =>
- write!(dst, "{}", stanza)
- .map_err(|_| Error::from(ErrorKind::InvalidInput)),
+ write!(dst, "{}", stanza),
// TODO: Implement all
_ => Ok(())
}
+ .map_err(|_| Error::from(ErrorKind::InvalidInput))
}
}
@@ -0,0 +1,62 @@
+use std::sync::Arc;
+use std::collections::HashMap;
+use futures::*;
+use tokio_io::{AsyncRead, AsyncWrite};
+use tokio_io::codec::Framed;
+use rustls::ClientConfig;
+use xml;
+
+use xmpp_codec::*;
+use stream_start::*;
+use starttls::{NS_XMPP_TLS, StartTlsClient};
+
+pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
+
+pub struct XMPPStream<S> {
+ pub stream: Framed<S, XMPPCodec>,
+ pub stream_attrs: HashMap<String, String>,
+ pub stream_features: xml::Element,
+}
+
+impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
+ pub fn from_stream(stream: S, to: String) -> StreamStart<S> {
+ let xmpp_stream = AsyncRead::framed(stream, XMPPCodec::new());
+ StreamStart::from_stream(xmpp_stream, to)
+ }
+
+ pub fn into_inner(self) -> S {
+ self.stream.into_inner()
+ }
+
+ pub fn can_starttls(&self) -> bool {
+ self.stream_features
+ .get_child("starttls", Some(NS_XMPP_TLS))
+ .is_some()
+ }
+
+ pub fn starttls(self, arc_config: Arc<ClientConfig>) -> StartTlsClient<S> {
+ StartTlsClient::from_stream(self, arc_config)
+ }
+}
+
+impl<S: AsyncWrite> Sink for XMPPStream<S> {
+ type SinkItem = <Framed<S, XMPPCodec> as Sink>::SinkItem;
+ type SinkError = <Framed<S, XMPPCodec> as Sink>::SinkError;
+
+ fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
+ self.stream.start_send(item)
+ }
+
+ fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
+ self.stream.poll_complete()
+ }
+}
+
+impl<S: AsyncRead> Stream for XMPPStream<S> {
+ type Item = <Framed<S, XMPPCodec> as Stream>::Item;
+ type Error = <Framed<S, XMPPCodec> as Stream>::Error;
+
+ fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+ self.stream.poll()
+ }
+}