starttls works

Astro created

Change summary

Cargo.toml           |   2 
examples/echo_bot.rs |  21 ++++++
src/lib.rs           |   4 +
src/starttls.rs      | 142 ++++++++++++++++++++++++++++++++++++++++++++++
src/tcp.rs           |   3 
src/xmpp_codec.rs    |  10 +++
6 files changed, 178 insertions(+), 4 deletions(-)

Detailed changes

Cargo.toml 🔗

@@ -9,3 +9,5 @@ tokio-core = "*"
 tokio-io = "*"
 bytes = "*"
 RustyXML = "*"
+rustls = "*"
+tokio-rustls = "*"

examples/echo_bot.rs 🔗

@@ -1,10 +1,15 @@
 extern crate futures;
 extern crate tokio_core;
 extern crate tokio_xmpp;
+extern crate rustls;
 
+use std::sync::Arc;
+use std::io::BufReader;
+use std::fs::File;
 use tokio_core::reactor::Core;
 use futures::{Future, Stream};
-use tokio_xmpp::{Packet, TcpClient};
+use tokio_xmpp::{Packet, TcpClient, StartTlsClient};
+use rustls::ClientConfig;
 
 fn main() {
     use std::net::ToSocketAddrs;
@@ -12,10 +17,16 @@ fn main() {
         .to_socket_addrs().unwrap()
         .next().unwrap();
 
+    let mut config = ClientConfig::new();
+    let mut certfile = BufReader::new(File::open("/usr/share/ca-certificates/CAcert/root.crt").unwrap());
+    config.root_store.add_pem_file(&mut certfile).unwrap();
+    let arc_config = Arc::new(config);
+
     let mut core = Core::new().unwrap();
     let client = TcpClient::connect(
         &addr,
         &core.handle()
+    ).and_then(|stream| StartTlsClient::from_stream(stream, arc_config)
     ).and_then(|stream| {
         stream.for_each(|event| {
             match event {
@@ -25,5 +36,11 @@ fn main() {
             Ok(())
         })
     });
-    core.run(client).unwrap();
+    match core.run(client) {
+        Ok(_) => (),
+        Err(e) => {
+            println!("Fatal: {}", e);
+            ()
+        }
+    }
 }

src/lib.rs 🔗

@@ -4,12 +4,16 @@ extern crate tokio_core;
 extern crate tokio_io;
 extern crate bytes;
 extern crate xml;
+extern crate rustls;
+extern crate tokio_rustls;
 
 
 mod xmpp_codec;
 pub use xmpp_codec::*;
 mod tcp;
 pub use tcp::*;
+mod starttls;
+pub use starttls::*;
 
 
 // type FullClient = sasl::Client<StartTLS<TCPConnection>>

src/starttls.rs 🔗

@@ -0,0 +1,142 @@
+use std::mem::replace;
+use std::io::{Error, ErrorKind};
+use std::sync::Arc;
+use futures::{Future, Sink, Poll, Async};
+use futures::stream::Stream;
+use futures::sink;
+use tokio_core::net::TcpStream;
+use rustls::*;
+use tokio_rustls::*;
+use xml;
+
+use super::{XMPPStream, XMPPCodec, Packet};
+
+
+const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
+const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
+
+pub struct StartTlsClient {
+    state: StartTlsClientState,
+    arc_config: Arc<ClientConfig>,
+}
+
+enum StartTlsClientState {
+    Invalid,
+    AwaitFeatures(XMPPStream<TcpStream>),
+    SendStartTls(sink::Send<XMPPStream<TcpStream>>),
+    AwaitProceed(XMPPStream<TcpStream>),
+    StartingTls(ConnectAsync<TcpStream>),
+}
+
+impl StartTlsClient {
+    /// Waits for <stream:features>
+    pub fn from_stream(xmpp_stream: XMPPStream<TcpStream>, arc_config: Arc<ClientConfig>) -> Self {
+        StartTlsClient {
+            state: StartTlsClientState::AwaitFeatures(xmpp_stream),
+            arc_config: arc_config,
+        }
+    }
+}
+
+// TODO: eval <stream:features>, check ns
+impl Future for StartTlsClient {
+    type Item = XMPPStream<TlsStream<TcpStream, ClientSession>>;
+    type Error = Error;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
+        let mut retry = false;
+        
+        let (new_state, result) = match old_state {
+            StartTlsClientState::AwaitFeatures(mut xmpp_stream) =>
+                match xmpp_stream.poll() {
+                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
+                        if stanza.name == "features"
+                        && stanza.ns == Some(NS_XMPP_STREAM.to_owned())
+                        =>
+                    {
+                        println!("Got features: {}", stanza);
+                        match stanza.get_child("starttls", Some(NS_XMPP_TLS)) {
+                            None =>
+                                (StartTlsClientState::Invalid, Err(Error::from(ErrorKind::InvalidData))),
+                            Some(_) => {
+                                let nonza = xml::Element::new(
+                                    "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
+                                    vec![]
+                                );
+                                println!("send {}", nonza);
+                                let packet = Packet::Stanza(nonza);
+                                let send = xmpp_stream.send(packet);
+                                let new_state = StartTlsClientState::SendStartTls(send);
+                                retry = true;
+                                (new_state, Ok(Async::NotReady))
+                            },
+                        }
+                    },
+                    Ok(Async::Ready(value)) => {
+                        println!("StartTlsClient ignore {:?}", value);
+                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
+                    },
+                    Ok(_) =>
+                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Err(e)),
+                },
+            StartTlsClientState::SendStartTls(mut send) =>
+                match send.poll() {
+                    Ok(Async::Ready(xmpp_stream)) => {
+                        println!("starttls sent");
+                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
+                        retry = true;
+                        (new_state, Ok(Async::NotReady))
+                    },
+                    Ok(Async::NotReady) =>
+                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StartTlsClientState::SendStartTls(send), Err(e)),
+                },
+            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
+                match xmpp_stream.poll() {
+                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
+                        if stanza.name == "proceed" =>
+                    {
+                        println!("* proceed *");
+                        let stream = xmpp_stream.into_inner();
+                        let connect = self.arc_config.connect_async("spaceboyz.net", stream);
+                        let new_state = StartTlsClientState::StartingTls(connect);
+                        retry = true;
+                        (new_state, Ok(Async::NotReady))
+                    },
+                    Ok(Async::Ready(value)) => {
+                        println!("StartTlsClient ignore {:?}", value);
+                        (StartTlsClientState::AwaitFeatures(xmpp_stream), Ok(Async::NotReady))
+                    },
+                    Ok(_) =>
+                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(e)),
+                },
+            StartTlsClientState::StartingTls(mut connect) =>
+                match connect.poll() {
+                    Ok(Async::Ready(tls_stream)) => {
+                        println!("Got a TLS stream!");
+                        let xmpp_stream = XMPPCodec::frame_stream(tls_stream);
+                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream)))
+                    },
+                    Ok(Async::NotReady) =>
+                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
+                    Err(e) =>
+                        (StartTlsClientState::StartingTls(connect),  Err(e)),
+                },
+            StartTlsClientState::Invalid =>
+                unreachable!(),
+        };
+
+        self.state = new_state;
+        if retry {
+            self.poll()
+        } else {
+            result
+        }
+    }
+}

src/tcp.rs 🔗

@@ -5,7 +5,6 @@ use futures::{Future, Sink, Poll, Async};
 use futures::stream::Stream;
 use futures::sink;
 use tokio_core::reactor::Handle;
-use tokio_io::AsyncRead;
 use tokio_core::net::{TcpStream, TcpStreamNew};
 
 use super::{XMPPStream, XMPPCodec, Packet};
@@ -53,7 +52,7 @@ impl Future for TcpClient {
         let (new_state, result) = match self.state {
             TcpClientState::Connecting(ref mut tcp_stream_new) => {
                 let tcp_stream = try_ready!(tcp_stream_new.poll());
-                let xmpp_stream = AsyncRead::framed(tcp_stream, XMPPCodec::new());
+                let xmpp_stream = XMPPCodec::frame_stream(tcp_stream);
                 let send = xmpp_stream.send(Packet::StreamStart);
                 let new_state = TcpClientState::SendStart(send);
                 (new_state, Ok(Async::NotReady))

src/xmpp_codec.rs 🔗

@@ -3,6 +3,7 @@ use std::fmt::Write;
 use std::str::from_utf8;
 use std::io::{Error, ErrorKind};
 use std::collections::HashMap;
+use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_io::codec::{Framed, Encoder, Decoder};
 use xml;
 use bytes::*;
@@ -67,6 +68,12 @@ impl XMPPCodec {
             root: None,
         }
     }
+
+    pub fn frame_stream<S>(stream: S) -> Framed<S, XMPPCodec>
+        where S: AsyncRead + AsyncWrite
+    {
+        AsyncRead::framed(stream, XMPPCodec::new())
+    }
 }
 
 impl Decoder for XMPPCodec {
@@ -146,6 +153,9 @@ impl Encoder for XMPPCodec {
                        NS_CLIENT, NS_STREAMS)
                     .map_err(|_| Error::from(ErrorKind::WriteZero))
             },
+            Packet::Stanza(stanza) =>
+                write!(dst, "{}", stanza)
+                .map_err(|_| Error::from(ErrorKind::InvalidInput)),
             // TODO: Implement all
             _ => Ok(())
         }