1use futures::{sink::SinkExt, stream::StreamExt};
2use std::marker::Unpin;
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio_util::codec::Framed;
5use xmpp_parsers::{ns, Element, Jid};
6
7use crate::xmpp_codec::{Packet, XMPPCodec};
8use crate::xmpp_stream::XMPPStream;
9use crate::{Error, ProtocolError};
10
11/// Sends a `<stream:stream>`, then wait for one from the server, and
12/// construct an XMPPStream.
13pub async fn start<S: AsyncRead + AsyncWrite + Unpin>(
14 mut stream: Framed<S, XMPPCodec>,
15 jid: Jid,
16 ns: String,
17) -> Result<XMPPStream<S>, Error> {
18 let attrs = [
19 ("to".to_owned(), jid.clone().domain()),
20 ("version".to_owned(), "1.0".to_owned()),
21 ("xmlns".to_owned(), ns.clone()),
22 ("xmlns:stream".to_owned(), ns::STREAM.to_owned()),
23 ]
24 .iter()
25 .cloned()
26 .collect();
27 stream.send(Packet::StreamStart(attrs)).await?;
28
29 let stream_attrs;
30 loop {
31 match stream.next().await {
32 Some(Ok(Packet::StreamStart(attrs))) => {
33 stream_attrs = attrs;
34 break;
35 }
36 Some(Ok(_)) => {}
37 Some(Err(e)) => return Err(e.into()),
38 None => return Err(Error::Disconnected),
39 }
40 }
41
42 let stream_ns = stream_attrs
43 .get("xmlns")
44 .ok_or(ProtocolError::NoStreamNamespace)?
45 .clone();
46 let stream_id = stream_attrs
47 .get("id")
48 .ok_or(ProtocolError::NoStreamId)?
49 .clone();
50 let stream = if stream_ns == "jabber:client" && stream_attrs.get("version").is_some() {
51 let stream_features;
52 loop {
53 match stream.next().await {
54 Some(Ok(Packet::Stanza(stanza))) if stanza.is("features", ns::STREAM) => {
55 stream_features = stanza;
56 break;
57 }
58 Some(Ok(_)) => {}
59 Some(Err(e)) => return Err(e.into()),
60 None => return Err(Error::Disconnected),
61 }
62 }
63 XMPPStream::new(jid, stream, ns, stream_id, stream_features)
64 } else {
65 // FIXME: huge hack, shouldn’t be an element!
66 XMPPStream::new(
67 jid,
68 stream,
69 ns,
70 stream_id.clone(),
71 Element::builder(stream_id, ns::STREAM).build(),
72 )
73 };
74 Ok(stream)
75}