1// Copyright (c) 2025 Saarko <saarko@tutanota.com>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
//! [`DirectTlsServerConnector`] provides a [`ServerConnector`] for direct TLS connections
8
9use alloc::borrow::Cow;
10use sasl::common::ChannelBinding;
11use tokio::{io::BufStream, net::TcpStream};
12use xmpp_parsers::jid::Jid;
13
14use crate::{
15 connect::{
16 tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
17 DnsConfig, ServerConnector,
18 },
19 error::Error,
20 xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
21};
22
23/// Connect via direct TLS to an XMPP server
24#[derive(Debug, Clone)]
25pub struct DirectTlsServerConnector(pub DnsConfig);
26
27impl From<DnsConfig> for DirectTlsServerConnector {
28 fn from(dns_config: DnsConfig) -> DirectTlsServerConnector {
29 Self(dns_config)
30 }
31}
32
33impl ServerConnector for DirectTlsServerConnector {
34 type Stream = BufStream<TlsStream<TcpStream>>;
35
36 async fn connect(
37 &self,
38 jid: &Jid,
39 ns: &'static str,
40 timeouts: Timeouts,
41 ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
42 let tcp_stream = self.0.resolve().await?;
43
44 // Immediately establish TLS connection
45 let (tls_stream, channel_binding) =
46 establish_tls_connection(tcp_stream, jid.domain().as_str()).await?;
47
48 // Establish XMPP stream over TLS
49 Ok((
50 initiate_stream(
51 tokio::io::BufStream::new(tls_stream),
52 ns,
53 StreamHeader {
54 to: Some(Cow::Borrowed(jid.domain().as_str())),
55 // Setting explicitly `from` here, because server require it
56 // in order to advertise i.e. SASL2 (XEP-0388).
57 from: Some(Cow::Borrowed(jid.to_bare().as_str())),
58 id: None,
59 },
60 timeouts,
61 )
62 .await?,
63 channel_binding,
64 ))
65 }
66}
67
68/// Direct TLS ServerConnector Error - now just an alias to the common error type
69pub type DirectTlsError = TlsConnectorError;