refactor into stream_start + xmpp_stream

Astro created

Change summary

examples/echo_bot.rs |  10 +++
src/lib.rs           |   5 +
src/starttls.rs      |  79 ++++++++++++++---------------------
src/stream_start.rs  | 102 ++++++++++++++++++++++++++++++++++++++++++++++
src/tcp.rs           |  52 ++++-------------------
src/xmpp_codec.rs    |  46 ++++++++++----------
src/xmpp_stream.rs   |  62 +++++++++++++++++++++++++++
7 files changed, 239 insertions(+), 117 deletions(-)

Detailed changes

examples/echo_bot.rs 🔗

@@ -8,7 +8,8 @@ use std::io::BufReader;
 use std::fs::File;
 use tokio_core::reactor::Core;
 use futures::{Future, Stream};
-use tokio_xmpp::{Packet, TcpClient, StartTlsClient};
+use tokio_xmpp::TcpClient;
+use tokio_xmpp::xmpp_codec::Packet;
 use rustls::ClientConfig;
 
 fn main() {
@@ -26,8 +27,13 @@ fn main() {
     let client = TcpClient::connect(
         &addr,
         &core.handle()
-    ).and_then(|stream| StartTlsClient::from_stream(stream, arc_config)
     ).and_then(|stream| {
+        if stream.can_starttls() {
+            stream.starttls(arc_config)
+        } else {
+            panic!("No STARTTLS")
+        }
+    }).and_then(|stream| {
         stream.for_each(|event| {
             match event {
                 Packet::Stanza(el) => println!("<< {}", el),

src/lib.rs 🔗

@@ -8,8 +8,9 @@ extern crate rustls;
 extern crate tokio_rustls;
 
 
-mod xmpp_codec;
-pub use xmpp_codec::*;
+pub mod xmpp_codec;
+pub mod xmpp_stream;
+mod stream_start;
 mod tcp;
 pub use tcp::*;
 mod starttls;

src/starttls.rs 🔗

@@ -1,5 +1,5 @@
 use std::mem::replace;
-use std::io::{Error, ErrorKind};
+use std::io::Error;
 use std::sync::Arc;
 use futures::{Future, Sink, Poll, Async};
 use futures::stream::Stream;
@@ -9,36 +9,44 @@ use rustls::*;
 use tokio_rustls::*;
 use xml;
 
-use super::{XMPPStream, XMPPCodec, Packet};
+use xmpp_codec::*;
+use xmpp_stream::*;
+use stream_start::StreamStart;
 
 
-const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
-const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
+pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 
-pub struct StartTlsClient<S: AsyncWrite> {
+pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
     state: StartTlsClientState<S>,
     arc_config: Arc<ClientConfig>,
 }
 
-enum StartTlsClientState<S: AsyncWrite> {
+enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
     Invalid,
-    AwaitFeatures(XMPPStream<S>),
     SendStartTls(sink::Send<XMPPStream<S>>),
     AwaitProceed(XMPPStream<S>),
     StartingTls(ConnectAsync<S>),
+    Start(StreamStart<TlsStream<S, ClientSession>>),
 }
 
-impl<S: AsyncWrite> StartTlsClient<S> {
+impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
     /// Waits for <stream:features>
     pub fn from_stream(xmpp_stream: XMPPStream<S>, arc_config: Arc<ClientConfig>) -> Self {
+        let nonza = xml::Element::new(
+            "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
+            vec![]
+        );
+        println!("send {}", nonza);
+        let packet = Packet::Stanza(nonza);
+        let send = xmpp_stream.send(packet);
+
         StartTlsClient {
-            state: StartTlsClientState::AwaitFeatures(xmpp_stream),
+            state: StartTlsClientState::SendStartTls(send),
             arc_config: arc_config,
         }
     }
 }
 
-// TODO: eval <stream:features>, check ns
 impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
     type Item = XMPPStream<TlsStream<S, ClientSession>>;
     type Error = Error;
@@ -48,40 +56,6 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
         let mut retry = false;
         
         let (new_state, result) = match old_state {
-            StartTlsClientState::AwaitFeatures(mut xmpp_stream) =>
-                match xmpp_stream.poll() {
-                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
-                        if stanza.name == "features"
-                        && stanza.ns == Some(NS_XMPP_STREAM.to_owned())
-                        =>
-                    {
-                        println!("Got features: {}", stanza);
-                        match stanza.get_child("starttls", Some(NS_XMPP_TLS)) {
-                            None =>
-                                (StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))),
-                            Some(_) => {
-                                let nonza = xml::Element::new(
-                                    "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
-                                    vec![]
-                                );
-                                println!("send {}", nonza);
-                                let packet = Packet::Stanza(nonza);
-                                let send = xmpp_stream.send(packet);
-                                let new_state = StartTlsClientState::SendStartTls(send);
-                                retry = true;
-                                (new_state, Ok(Async::NotReady))
-                            },
-                        }
-                    },
-                    Ok(Async::Ready(value)) => {
-                        println!("StartTlsClient ignore {:?}", value);
-                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
-                    },
-                    Ok(_) =>
-                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)),
-                    Err(e) =>
-                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)),
-                },
             StartTlsClientState::SendStartTls(mut send) =>
                 match send.poll() {
                     Ok(Async::Ready(xmpp_stream)) => {
@@ -109,7 +83,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
                     },
                     Ok(Async::Ready(value)) => {
                         println!("StartTlsClient ignore {:?}", value);
-                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
+                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
                     },
                     Ok(_) =>
                         (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
@@ -120,14 +94,25 @@ impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
                 match connect.poll() {
                     Ok(Async::Ready(tls_stream)) => {
                         println!("Got a TLS stream!");
-                        let xmpp_stream = XMPPCodec::frame_stream(tls_stream);
-                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream)))
+                        let start = XMPPStream::from_stream(tls_stream, "spaceboyz.net".to_owned());
+                        let new_state = StartTlsClientState::Start(start);
+                        retry = true;
+                        (new_state, Ok(Async::NotReady))
                     },
                     Ok(Async::NotReady) =>
                         (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
                     Err(e) =>
                         (StartTlsClientState::StartingTls(connect),  Err(e)),
                 },
+            StartTlsClientState::Start(mut start) =>
+                match start.poll() {
+                    Ok(Async::Ready(xmpp_stream)) =>
+                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
+                    Ok(Async::NotReady) =>
+                        (StartTlsClientState::Start(start), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StartTlsClientState::Invalid, Err(e)),
+                },
             StartTlsClientState::Invalid =>
                 unreachable!(),
         };

src/stream_start.rs 🔗

@@ -0,0 +1,102 @@
+use std::mem::replace;
+use std::io::{Error, ErrorKind};
+use std::collections::HashMap;
+use futures::*;
+use tokio_io::{AsyncRead, AsyncWrite};
+use tokio_io::codec::Framed;
+
+use xmpp_codec::*;
+use xmpp_stream::*;
+
+const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
+
+pub struct StreamStart<S: AsyncWrite> {
+    state: StreamStartState<S>,
+}
+
+enum StreamStartState<S: AsyncWrite> {
+    SendStart(sink::Send<Framed<S, XMPPCodec>>),
+    RecvStart(Framed<S, XMPPCodec>),
+    RecvFeatures(Framed<S, XMPPCodec>, HashMap<String, String>),
+    Invalid,
+}
+
+impl<S: AsyncWrite> StreamStart<S> {
+    pub fn from_stream(stream: Framed<S, XMPPCodec>, to: String) -> Self {
+        let attrs = [("to".to_owned(), to),
+                     ("version".to_owned(), "1.0".to_owned()),
+                     ("xmlns".to_owned(), "jabber:client".to_owned()),
+                     ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
+        ].iter().cloned().collect();
+        let send = stream.send(Packet::StreamStart(attrs));
+
+        StreamStart {
+            state: StreamStartState::SendStart(send),
+        }
+    }
+}
+
+impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
+    type Item = XMPPStream<S>;
+    type Error = Error;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        let old_state = replace(&mut self.state, StreamStartState::Invalid);
+        let mut retry = false;
+
+        let (new_state, result) = match old_state {
+            StreamStartState::SendStart(mut send) =>
+                match send.poll() {
+                    Ok(Async::Ready(stream)) => {
+                        retry = true;
+                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
+                    },
+                    Ok(Async::NotReady) =>
+                        (StreamStartState::SendStart(send), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StreamStartState::Invalid, Err(e)),
+                },
+            StreamStartState::RecvStart(mut stream) =>
+                match stream.poll() {
+                    Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
+                        retry = true;
+                        // TODO: skip RecvFeatures for version < 1.0
+                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+                    },
+                    Ok(Async::Ready(_)) =>
+                        return Err(Error::from(ErrorKind::InvalidData)),
+                    Ok(Async::NotReady) =>
+                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
+                    Err(e) =>
+                        return Err(e),
+                },
+            StreamStartState::RecvFeatures(mut stream, stream_attrs) =>
+                match stream.poll() {
+                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
+                        if stanza.name == "features"
+                        && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) {
+                            (StreamStartState::Invalid, Ok(Async::Ready(XMPPStream { stream, stream_attrs, stream_features: stanza })))
+                        } else {
+                            (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+                        },
+                    Ok(Async::Ready(item)) => {
+                        println!("StreamStart skip {:?}", item);
+                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
+                    },
+                    Ok(Async::NotReady) =>
+                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)),
+                    Err(e) =>
+                        return Err(e),
+                },
+            StreamStartState::Invalid =>
+                unreachable!(),
+        };
+
+        self.state = new_state;
+        if retry {
+            self.poll()
+        } else {
+            result
+        }
+    }
+}

src/tcp.rs 🔗

@@ -1,40 +1,22 @@
-use std::fmt;
 use std::net::SocketAddr;
-use std::io::{Error, ErrorKind};
-use futures::{Future, Sink, Poll, Async};
-use futures::stream::Stream;
-use futures::sink;
+use std::io::Error;
+use futures::{Future, Poll, Async};
 use tokio_core::reactor::Handle;
 use tokio_core::net::{TcpStream, TcpStreamNew};
 
-use super::{XMPPStream, XMPPCodec, Packet};
+use xmpp_stream::*;
+use stream_start::StreamStart;
 
-
-#[derive(Debug)]
 pub struct TcpClient {
     state: TcpClientState,
 }
 
 enum TcpClientState {
     Connecting(TcpStreamNew),
-    SendStart(sink::Send<XMPPStream<TcpStream>>),
-    RecvStart(Option<XMPPStream<TcpStream>>),
+    Start(StreamStart<TcpStream>),
     Established,
 }
 
-impl fmt::Debug for TcpClientState {
-    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
-        let s = match *self {
-            TcpClientState::Connecting(_) => "Connecting",
-            TcpClientState::SendStart(_) => "SendStart",
-            TcpClientState::RecvStart(_) => "RecvStart",
-            TcpClientState::Established => "Established",
-        };
-        try!(write!(fmt, "{}", s));
-        Ok(())
-    }
-}
-
 impl TcpClient {
     pub fn connect(addr: &SocketAddr, handle: &Handle) -> Self {
         let tcp_stream_new = TcpStream::connect(addr, handle);
@@ -52,27 +34,12 @@ impl Future for TcpClient {
         let (new_state, result) = match self.state {
             TcpClientState::Connecting(ref mut tcp_stream_new) => {
                 let tcp_stream = try_ready!(tcp_stream_new.poll());
-                let xmpp_stream = XMPPCodec::frame_stream(tcp_stream);
-                let send = xmpp_stream.send(Packet::StreamStart);
-                let new_state = TcpClientState::SendStart(send);
-                (new_state, Ok(Async::NotReady))
-            },
-            TcpClientState::SendStart(ref mut send) => {
-                let xmpp_stream = try_ready!(send.poll());
-                let new_state = TcpClientState::RecvStart(Some(xmpp_stream));
+                let start = XMPPStream::from_stream(tcp_stream, "spaceboyz.net".to_owned());
+                let new_state = TcpClientState::Start(start);
                 (new_state, Ok(Async::NotReady))
             },
-            TcpClientState::RecvStart(ref mut opt_xmpp_stream) => {
-                let mut xmpp_stream = opt_xmpp_stream.take().unwrap();
-                match xmpp_stream.poll() {
-                    Ok(Async::Ready(Some(Packet::StreamStart))) => println!("Recv start!"),
-                    Ok(Async::Ready(_)) => return Err(Error::from(ErrorKind::InvalidData)),
-                    Ok(Async::NotReady) => {
-                        *opt_xmpp_stream = Some(xmpp_stream);
-                        return Ok(Async::NotReady);
-                    },
-                    Err(e) => return Err(e)
-                };
+            TcpClientState::Start(ref mut start) => {
+                let xmpp_stream = try_ready!(start.poll());
                 let new_state = TcpClientState::Established;
                 (new_state, Ok(Async::Ready(xmpp_stream)))
             },
@@ -80,7 +47,6 @@ impl Future for TcpClient {
                 unreachable!(),
         };
 
-        println!("Next state: {:?}", new_state);
         self.state = new_state;
 	match result {
 	    // by polling again, we register new future

src/xmpp_codec.rs 🔗

@@ -3,18 +3,17 @@ use std::fmt::Write;
 use std::str::from_utf8;
 use std::io::{Error, ErrorKind};
 use std::collections::HashMap;
-use tokio_io::{AsyncRead, AsyncWrite};
-use tokio_io::codec::{Framed, Encoder, Decoder};
+use tokio_io::codec::{Encoder, Decoder};
 use xml;
 use bytes::*;
 
 const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
-const NS_STREAMS: &'static str = "http://etherx.jabber.org/streams";
-const NS_CLIENT: &'static str = "jabber:client";
+
+pub type Attributes = HashMap<(String, Option<String>), String>;
 
 struct XMPPRoot {
     builder: xml::ElementBuilder,
-    pub attributes: HashMap<(String, Option<String>), String>,
+    pub attributes: Attributes,
 }
 
 impl XMPPRoot {
@@ -49,13 +48,11 @@ impl XMPPRoot {
 #[derive(Debug)]
 pub enum Packet {
     Error(Box<std::error::Error>),
-    StreamStart,
+    StreamStart(HashMap<String, String>),
     Stanza(xml::Element),
     StreamEnd,
 }
 
-pub type XMPPStream<T> = Framed<T, XMPPCodec>;
-
 pub struct XMPPCodec {
     parser: xml::Parser,
     root: Option<XMPPRoot>,
@@ -68,12 +65,6 @@ impl XMPPCodec {
             root: None,
         }
     }
-
-    pub fn frame_stream<S>(stream: S) -> Framed<S, XMPPCodec>
-        where S: AsyncRead + AsyncWrite
-    {
-        AsyncRead::framed(stream, XMPPCodec::new())
-    }
 }
 
 impl Decoder for XMPPCodec {
@@ -97,8 +88,12 @@ impl Decoder for XMPPCodec {
                     // Expecting <stream:stream>
                     match event {
                         Ok(xml::Event::ElementStart(start_tag)) => {
+                            let mut attrs: HashMap<String, String> = HashMap::new();
+                            for (&(ref name, _), value) in &start_tag.attributes {
+                                attrs.insert(name.to_owned(), value.to_owned());
+                            }
+                            result = Some(Packet::StreamStart(attrs));
                             self.root = Some(XMPPRoot::new(start_tag));
-                            result = Some(Packet::StreamStart);
                             break
                         },
                         Err(e) => {
@@ -146,18 +141,23 @@ impl Encoder for XMPPCodec {
 
     fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
         match item {
-            Packet::StreamStart => {
-                write!(dst,
-                       "<?xml version='1.0'?>\n
-<stream:stream version='1.0' to='spaceboyz.net' xmlns='{}' xmlns:stream='{}'>\n",
-                       NS_CLIENT, NS_STREAMS)
-                    .map_err(|_| Error::from(ErrorKind::WriteZero))
+            Packet::StreamStart(start_attrs) => {
+                let mut buf = String::new();
+                write!(buf, "<stream:stream").unwrap();
+                for (ref name, ref value) in &start_attrs {
+                    write!(buf, " {}=\"{}\"", xml::escape(&name), xml::escape(&value))
+                        .unwrap();
+                }
+                write!(buf, ">\n").unwrap();
+
+                println!("Encode start to {}", buf);
+                write!(dst, "{}", buf)
             },
             Packet::Stanza(stanza) =>
-                write!(dst, "{}", stanza)
-                .map_err(|_| Error::from(ErrorKind::InvalidInput)),
+                write!(dst, "{}", stanza),
             // TODO: Implement all
             _ => Ok(())
         }
+        .map_err(|_| Error::from(ErrorKind::InvalidInput))
     }
 }

src/xmpp_stream.rs 🔗

@@ -0,0 +1,62 @@
+use std::sync::Arc;
+use std::collections::HashMap;
+use futures::*;
+use tokio_io::{AsyncRead, AsyncWrite};
+use tokio_io::codec::Framed;
+use rustls::ClientConfig;
+use xml;
+
+use xmpp_codec::*;
+use stream_start::*;
+use starttls::{NS_XMPP_TLS, StartTlsClient};
+
+pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
+
+pub struct XMPPStream<S> {
+    pub stream: Framed<S, XMPPCodec>,
+    pub stream_attrs: HashMap<String, String>,
+    pub stream_features: xml::Element,
+}
+
+impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
+    pub fn from_stream(stream: S, to: String) -> StreamStart<S> {
+        let xmpp_stream = AsyncRead::framed(stream, XMPPCodec::new());
+        StreamStart::from_stream(xmpp_stream, to)
+    }
+
+    pub fn into_inner(self) -> S {
+        self.stream.into_inner()
+    }
+
+    pub fn can_starttls(&self) -> bool {
+        self.stream_features
+            .get_child("starttls", Some(NS_XMPP_TLS))
+            .is_some()
+    }
+
+    pub fn starttls(self, arc_config: Arc<ClientConfig>) -> StartTlsClient<S> {
+        StartTlsClient::from_stream(self, arc_config)
+    }
+}
+
+impl<S: AsyncWrite> Sink for XMPPStream<S> {
+    type SinkItem = <Framed<S, XMPPCodec> as Sink>::SinkItem;
+    type SinkError = <Framed<S, XMPPCodec> as Sink>::SinkError;
+
+    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
+        self.stream.start_send(item)
+    }
+
+    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
+        self.stream.poll_complete()
+    }
+}
+
+impl<S: AsyncRead> Stream for XMPPStream<S> {
+    type Item = <Framed<S, XMPPCodec> as Stream>::Item;
+    type Error = <Framed<S, XMPPCodec> as Stream>::Error;
+
+    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+        self.stream.poll()
+    }
+}