Detailed changes
@@ -41,7 +41,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
[features]
default = ["starttls-rust"]
-starttls = ["hickory-resolver", "idna"]
+starttls = ["dns"]
tls-rust = ["tokio-rustls", "webpki-roots"]
tls-native = ["tokio-native-tls", "native-tls"]
starttls-native = ["starttls", "tls-native"]
@@ -50,6 +50,8 @@ insecure-tcp = []
syntax-highlighting = ["syntect"]
# Enable serde support in jid crate
serde = [ "xmpp-parsers/serde" ]
+# Required by starttls, and used by insecure-tcp by default
+dns = [ "hickory-resolver", "idna" ]
[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(xmpprs_doc_build)'] }
@@ -10,9 +10,11 @@ use super::connect::client_login;
use crate::connect::{AsyncReadAndWrite, ServerConnector};
use crate::error::{Error, ProtocolError};
use crate::event::Event;
+#[cfg(feature = "starttls")]
use crate::starttls::ServerConfig;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
+#[cfg(feature = "starttls")]
use crate::AsyncConfig;
/// XMPP client connection and state
@@ -46,6 +48,7 @@ enum ClientState<S: AsyncReadAndWrite> {
Connected(XMPPStream<S>),
}
+#[cfg(feature = "starttls")]
impl Client<ServerConfig> {
/// Start a new XMPP client using StartTLS transport and autoreconnect
///
@@ -1,12 +1,14 @@
use futures::{sink::SinkExt, Sink, Stream};
use minidom::Element;
use std::pin::Pin;
+#[cfg(feature = "starttls")]
use std::str::FromStr;
use std::task::{Context, Poll};
use tokio_stream::StreamExt;
use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures};
use crate::connect::ServerConnector;
+#[cfg(feature = "starttls")]
use crate::starttls::ServerConfig;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::{add_stanza_id, XMPPStream};
@@ -22,6 +24,7 @@ pub struct Client<C: ServerConnector> {
stream: XMPPStream<C::Stream>,
}
+#[cfg(feature = "starttls")]
impl Client<ServerConfig> {
/// Start a new XMPP client and wait for a usable session
pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
@@ -1,7 +1,17 @@
//! `ServerConnector` provides streams for XMPP clients
+#[cfg(feature = "dns")]
+use futures::{future::select_ok, FutureExt};
+#[cfg(feature = "dns")]
+use hickory_resolver::{
+ config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
+};
+#[cfg(feature = "dns")]
+use log::debug;
use sasl::common::ChannelBinding;
+use std::net::{IpAddr, SocketAddr};
use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::TcpStream;
use xmpp_parsers::jid::Jid;
use crate::xmpp_stream::XMPPStream;
@@ -32,3 +42,78 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
Ok(ChannelBinding::None)
}
}
+
+/// A simple wrapper to build [`TcpStream`]
+pub struct Tcp;
+
+impl Tcp {
+ /// Connect directly to an IP/Port combo
+ pub async fn connect(ip: IpAddr, port: u16) -> Result<TcpStream, Error> {
+ Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?)
+ }
+
+ /// Connect over TCP, resolving A/AAAA records (happy eyeballs)
+ #[cfg(feature = "dns")]
+ pub async fn resolve(domain: &str, port: u16) -> Result<TcpStream, Error> {
+ let ascii_domain = idna::domain_to_ascii(&domain)?;
+
+ if let Ok(ip) = ascii_domain.parse() {
+ return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
+ }
+
+ let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
+ options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
+ let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
+
+ let ips = resolver.lookup_ip(ascii_domain).await?;
+
+ // Happy Eyeballs: connect to all records in parallel, return the
+ // first to succeed
+ select_ok(
+ ips.into_iter()
+ .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
+ )
+ .await
+ .map(|(result, _)| result)
+ .map_err(|_| Error::Disconnected)
+ }
+
+ /// Connect over TCP, resolving SRV records
+ #[cfg(feature = "dns")]
+ pub async fn resolve_with_srv(
+ domain: &str,
+ srv: &str,
+ fallback_port: u16,
+ ) -> Result<TcpStream, Error> {
+ let ascii_domain = idna::domain_to_ascii(&domain)?;
+
+ if let Ok(ip) = ascii_domain.parse() {
+ debug!("Attempting connection to {ip}:{fallback_port}");
+ return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
+ }
+
+ let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
+
+ let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
+ let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
+
+ match srv_records {
+ Some(lookup) => {
+ // TODO: sort lookup records by priority/weight
+ for srv in lookup.iter() {
+ debug!("Attempting connection to {srv_domain} {srv}");
+ match Self::resolve(&srv.target().to_ascii(), srv.port()).await {
+ Ok(stream) => return Ok(stream),
+ Err(_) => {}
+ }
+ }
+ Err(Error::Disconnected)
+ }
+ None => {
+ // SRV lookup error, retry with hostname
+ debug!("Attempting connection to {domain}:{fallback_port}");
+ Self::resolve(domain, fallback_port).await
+ }
+ }
+ }
+}
@@ -1,3 +1,7 @@
+#[cfg(feature = "dns")]
+use hickory_resolver::{
+ error::ResolveError as DnsResolveError, proto::error::ProtoError as DnsProtoError,
+};
use sasl::client::MechanismError as SaslMechanismError;
use std::error::Error as StdError;
use std::fmt;
@@ -28,8 +32,18 @@ pub enum Error {
Fmt(fmt::Error),
/// Utf8 error
Utf8(Utf8Error),
- /// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
+ /// Error specific to ServerConnector impl
Connection(Box<dyn ServerConnectorError>),
+ /// DNS protocol error
+ #[cfg(feature = "dns")]
+ Dns(DnsProtoError),
+ /// DNS resolution error
+ #[cfg(feature = "dns")]
+ Resolve(DnsResolveError),
+ /// DNS label conversion error, no details available from module
+ /// `idna`
+ #[cfg(feature = "dns")]
+ Idna,
}
impl fmt::Display for Error {
@@ -44,6 +58,12 @@ impl fmt::Display for Error {
Error::InvalidState => write!(fmt, "invalid state"),
Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e),
+ #[cfg(feature = "dns")]
+ Error::Dns(e) => write!(fmt, "{:?}", e),
+ #[cfg(feature = "dns")]
+ Error::Resolve(e) => write!(fmt, "{:?}", e),
+ #[cfg(feature = "dns")]
+ Error::Idna => write!(fmt, "IDNA error"),
}
}
}
@@ -92,6 +112,27 @@ impl From<Utf8Error> for Error {
}
}
+#[cfg(feature = "dns")]
+impl From<idna::Errors> for Error {
+ fn from(_e: idna::Errors) -> Self {
+ Error::Idna
+ }
+}
+
+#[cfg(feature = "dns")]
+impl From<DnsResolveError> for Error {
+ fn from(e: DnsResolveError) -> Error {
+ Error::Resolve(e)
+ }
+}
+
+#[cfg(feature = "dns")]
+impl From<DnsProtoError> for Error {
+ fn from(e: DnsProtoError) -> Error {
+ Error::Dns(e)
+ }
+}
+
/// XMPP protocol-level error
#[derive(Debug)]
pub enum ProtocolError {
@@ -1,6 +1,5 @@
//! StartTLS ServerConnector Error
-use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
#[cfg(feature = "tls-native")]
use native_tls::Error as TlsError;
use std::error::Error as StdError;
@@ -15,13 +14,6 @@ use super::ServerConnectorError;
/// StartTLS ServerConnector Error
#[derive(Debug)]
pub enum Error {
- /// DNS protocol error
- Dns(ProtoError),
- /// DNS resolution error
- Resolve(ResolveError),
- /// DNS label conversion error, no details available from module
- /// `idna`
- Idna,
/// TLS error
Tls(TlsError),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@@ -34,9 +26,6 @@ impl ServerConnectorError for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
- Self::Dns(e) => write!(fmt, "{:?}", e),
- Self::Resolve(e) => write!(fmt, "{:?}", e),
- Self::Idna => write!(fmt, "IDNA error"),
Self::Tls(e) => write!(fmt, "TLS error: {}", e),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
@@ -1,75 +0,0 @@
-use super::error::Error as StartTlsError;
-use crate::Error;
-use futures::{future::select_ok, FutureExt};
-use hickory_resolver::{
- config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
-};
-use log::debug;
-use std::net::SocketAddr;
-use tokio::net::TcpStream;
-
-pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
- let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
-
- if let Ok(ip) = ascii_domain.parse() {
- return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
- }
-
- let (config, mut options) =
- hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?;
- options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
- let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
-
- let ips = resolver
- .lookup_ip(ascii_domain)
- .await
- .map_err(StartTlsError::Resolve)?;
- // Happy Eyeballs: connect to all records in parallel, return the
- // first to succeed
- select_ok(
- ips.into_iter()
- .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
- )
- .await
- .map(|(result, _)| result)
- .map_err(|_| crate::Error::Disconnected)
-}
-
-pub async fn connect_with_srv(
- domain: &str,
- srv: &str,
- fallback_port: u16,
-) -> Result<TcpStream, Error> {
- let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
-
- if let Ok(ip) = ascii_domain.parse() {
- debug!("Attempting connection to {ip}:{fallback_port}");
- return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
- }
-
- let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?;
-
- let srv_domain = format!("{}.{}.", srv, ascii_domain)
- .into_name()
- .map_err(StartTlsError::Dns)?;
- let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
-
- match srv_records {
- Some(lookup) => {
- // TODO: sort lookup records by priority/weight
- for srv in lookup.iter() {
- debug!("Attempting connection to {srv_domain} {srv}");
- match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
- Ok(stream) => return Ok(stream),
- Err(_) => {}
- }
- }
- Err(crate::Error::Disconnected.into())
- }
- None => {
- // SRV lookup error, retry with hostname
- debug!("Attempting connection to {domain}:{fallback_port}");
- connect_to_host(domain, fallback_port).await
- }
- }
-}
@@ -27,16 +27,17 @@ use tokio::{
};
use xmpp_parsers::{jid::Jid, ns};
-use crate::error::ProtocolError;
-use crate::Error;
-use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient};
-use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
+use crate::{
+ connect::{ServerConnector, ServerConnectorError, Tcp},
+ error::{Error, ProtocolError},
+ xmpp_codec::Packet,
+ xmpp_stream::XMPPStream,
+ AsyncClient, SimpleClient,
+};
use self::error::Error as StartTlsError;
-use self::happy_eyeballs::{connect_to_host, connect_with_srv};
pub mod error;
-mod happy_eyeballs;
/// AsyncClient that connects over StartTls
pub type StartTlsAsyncClient = AsyncClient<ServerConfig>;
@@ -64,9 +65,9 @@ impl ServerConnector for ServerConfig {
// TCP connection
let tcp_stream = match self {
ServerConfig::UseSrv => {
- connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
+ Tcp::resolve_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
}
- ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
+ ServerConfig::Manual { host, port } => Tcp::resolve(host.as_str(), *port).await?,
};
// Unencryped XMPPStream