1
using System.Collections.Generic;
2
using System.IO;
3
using System.Linq;
4
using System.Security.Cryptography;
5
using System.Threading.Tasks;
6
using Libplanet.Stun.Attributes;
7

8
namespace Libplanet.Stun.Messages
9
{
10
    /// <summary>
    /// Base class of STUN messages (including the TURN messages built on top of STUN).
    /// It encodes a message into its wire format and parses one back from a
    /// <see cref="Stream"/>.
    /// </summary>
    public abstract class StunMessage
    {
        private const int HeaderBytes = 20;
        private const int MessageIntegrityBytes = 24;
        private const int FingerprintBytes = 8;
        private const int TransactionIdBytes = 12;
        private const int AttributeHeaderBytes = 4;

        /// <summary>
        /// Creates a message with a freshly generated, cryptographically random
        /// <see cref="TransactionId"/>.
        /// </summary>
        protected StunMessage()
        {
            var transactionId = new byte[TransactionIdBytes];

            // RandomNumberGenerator.Create() replaces the obsolete RNGCryptoServiceProvider.
            using RandomNumberGenerator rng = RandomNumberGenerator.Create();
            rng.GetBytes(transactionId);
            TransactionId = transactionId;
        }

        /// <summary>
        /// The class of a STUN message, carried in the C1/C0 bits of the message type
        /// (RFC 5389, section 6). Registry:
        /// https://www.iana.org/assignments/stun-parameters/stun-parameters.xhtml.
        /// </summary>
        public enum MessageClass : byte
        {
            /// <summary>A request (C1C0 = 0b00).</summary>
            Request = 0x0,

            /// <summary>An indication (C1C0 = 0b01).</summary>
            Indication = 0x1,

            /// <summary>A success response (C1C0 = 0b10).</summary>
            SuccessResponse = 0x2,

            /// <summary>An error response (C1C0 = 0b11).</summary>
            ErrorResponse = 0x3,
        }

        /// <summary>
        /// The 12-bit method of a STUN message. Registry:
        /// https://www.iana.org/assignments/stun-parameters/stun-parameters.xhtml.
        /// </summary>
        public enum MessageMethod : ushort
        {
            /// <summary>Binding (RFC 5389).</summary>
            Binding = 0x001,

            /// <summary>Allocate (TURN, RFC 5766).</summary>
            Allocate = 0x003,

            /// <summary>Refresh (TURN, RFC 5766).</summary>
            Refresh = 0x004,

            /// <summary>Send (TURN, RFC 5766).</summary>
            Send = 0x006,

            /// <summary>Data (TURN, RFC 5766).</summary>
            Data = 0x007,

            /// <summary>CreatePermission (TURN, RFC 5766).</summary>
            CreatePermission = 0x008,

            /// <summary>ChannelBind (TURN, RFC 5766).</summary>
            ChannelBind = 0x009,

            /// <summary>Connect (TURN over TCP, RFC 6062).</summary>
            Connect = 0x00a,

            /// <summary>ConnectionBind (TURN over TCP, RFC 6062).</summary>
            ConnectionBind = 0x00b,

            /// <summary>ConnectionAttempt (TURN over TCP, RFC 6062).</summary>
            ConnectionAttempt = 0x00c,
        }

        /// <summary>
        /// A <see cref="MessageClass"/> of STUN packet.
        /// </summary>
        public abstract MessageClass Class { get; }

        /// <summary>
        /// A <see cref="MessageMethod"/> of STUN packet.
        /// </summary>
        public abstract MessageMethod Method { get; }

        /// <summary>
        /// A 96-bit identifier used to uniquely identify STUN transactions.
        /// </summary>
        public byte[] TransactionId { get; internal set; }

        /// <summary>
        /// A fixed value to distinguish STUN packets from packets of another protocol.
        /// A fresh array is returned on every access, so callers cannot mutate a shared copy.
        /// </summary>
        /// <remarks>It should be always 0x2112A442 in network byte order.</remarks>
        internal static byte[] MagicCookie => new byte[]
        {
            0x21, 0x12, 0xa4, 0x42,
        };

        /// <summary>
        /// A list of <see cref="Attribute"/> of STUN packet.
        /// </summary>
        protected IEnumerable<Attribute> Attributes { get; set; }

        /// <summary>
        /// Parses <see cref="StunMessage"/> from <paramref name="stream"/>.
        /// </summary>
        /// <param name="stream">A view of a sequence of STUN packet's bytes.</param>
        /// <returns>A <see cref="StunMessage"/> derived on
        /// bytes read from <paramref name="stream"/>.
        /// </returns>
        /// <exception cref="System.ArgumentNullException">Thrown when
        /// <paramref name="stream"/> is <c>null</c>.</exception>
        /// <exception cref="TurnClientException">Thrown when the stream ends before a
        /// complete message was read, or the class/method pair is not supported.</exception>
        public static async Task<StunMessage> Parse(Stream stream)
        {
            if (stream is null)
            {
                throw new System.ArgumentNullException(nameof(stream));
            }

            byte[] header = await ReadBytesAsync(stream, HeaderBytes).ConfigureAwait(false);

            MessageMethod method = ParseMethod(header[0], header[1]);
            MessageClass @class = ParseClass(header[0], header[1]);

            // Header layout: type (2) | length (2) | magic cookie (4) | transaction id (12).
            var length = new byte[2];
            System.Array.Copy(header, 2, length, 0, 2);

            var transactionId = new byte[TransactionIdBytes];
            System.Array.Copy(header, 8, transactionId, 0, TransactionIdBytes);

            byte[] body = await ReadBytesAsync(stream, length.ToUShort()).ConfigureAwait(false);

            StunMessage rv = CreateMessage(@class, method);
            if (rv is null)
            {
                throw new TurnClientException("Parsed result is null.");
            }

            rv.TransactionId = transactionId;

            // Materialize once; ParseAttributes is lazy, and leaving it lazy would make
            // every GetAttribute call re-parse the whole body.
            rv.Attributes = ParseAttributes(body, transactionId).ToList();

            return rv;
        }

        /// <summary>
        /// Encodes this message into its wire format.
        /// </summary>
        /// <param name="ctx">Credentials and server-provided values to include; may be
        /// <c>null</c>. A MESSAGE-INTEGRITY attribute is appended only when its
        /// <c>Username</c>, <c>Password</c> and <c>Realm</c> are all non-empty. A FINGERPRINT
        /// attribute is always appended last.</param>
        /// <returns>The 20-byte header followed by all encoded attributes.</returns>
        /// <exception cref="System.InvalidOperationException">Thrown when the encoded
        /// attributes do not fit the 16-bit message length field.</exception>
        public byte[] Encode(IStunContext ctx)
        {
            bool useMessageIntegrity =
                !string.IsNullOrEmpty(ctx?.Username) &&
                !string.IsNullOrEmpty(ctx?.Password) &&
                !string.IsNullOrEmpty(ctx?.Realm);

            // Interleave method bits and class bits into the 14-bit message type:
            // M11-M7 | C1 | M6-M4 | C0 | M3-M0 (RFC 5389, section 6).
            var c = (ushort)Class;
            var m = (ushort)Method;
            int type =
                (m & 0x0f80) << 2 |
                (m & 0x0070) << 1 |
                (m & 0x000f) << 0 |
                (c & 0x2) << 7 |
                (c & 0x1) << 4;

            // A derived message may leave Attributes unset; treat that as having none.
            List<Attribute> attrs = (Attributes ?? Enumerable.Empty<Attribute>()).ToList();

            if (!string.IsNullOrEmpty(ctx?.Username))
            {
                attrs.Add(new Username(ctx.Username));
            }

            if (ctx?.Nonce != null)
            {
                attrs.Add(new Attributes.Nonce(ctx.Nonce));
            }

            if (!string.IsNullOrEmpty(ctx?.Realm))
            {
                attrs.Add(new Realm(ctx.Realm));
            }

            byte[] encodedAttrs;
            using (var ams = new MemoryStream())
            {
                foreach (Attribute attr in attrs)
                {
                    byte[] asBytes = attr.ToByteArray(TransactionId);
                    ams.Write(asBytes, 0, asBytes.Length);
                }

                encodedAttrs = ams.ToArray();
            }

            // The length field counts every attribute byte, including the
            // MESSAGE-INTEGRITY and FINGERPRINT attributes appended below.
            int bodyLength =
                encodedAttrs.Length +
                FingerprintBytes +
                (useMessageIntegrity ? MessageIntegrityBytes : 0);
            if (bodyLength > ushort.MaxValue)
            {
                throw new System.InvalidOperationException(
                    $"Encoded attributes ({bodyLength} bytes) exceed the 16-bit STUN " +
                    "message length field.");
            }

            var messageLength = (ushort)bodyLength;
            byte[] magicCookie = MagicCookie;

            using var ms = new MemoryStream();
            ms.Write(((ushort)type).ToBytes(), 0, 2);
            ms.Write(messageLength.ToBytes(), 0, 2);
            ms.Write(magicCookie, 0, magicCookie.Length);
            ms.Write(TransactionId, 0, TransactionId.Length);
            ms.Write(encodedAttrs, 0, encodedAttrs.Length);

            if (useMessageIntegrity)
            {
                // MESSAGE-INTEGRITY is computed with a length field that stops right after
                // MESSAGE-INTEGRITY itself, i.e. excludes the trailing FINGERPRINT
                // (RFC 5389, section 15.4).
                var lengthWithoutFingerprint =
                    (ushort)(messageLength - FingerprintBytes);
                byte[] toCalc = ms.ToArray();
                lengthWithoutFingerprint.ToBytes().CopyTo(toCalc, 2);

                MessageIntegrity mi =
                    MessageIntegrity.Calculate(
                        ctx.Username,
                        ctx.Password,
                        ctx.Realm,
                        toCalc);
                ms.Write(mi.ToByteArray(), 0, MessageIntegrityBytes);
            }

            // The fingerprint is derived from everything written so far, with the final
            // length (which already includes the fingerprint) in the header.
            Fingerprint fingerprint = Fingerprint.FromMessage(
                ms.ToArray()
            );
            ms.Write(fingerprint.ToByteArray(), 0, FingerprintBytes);

            return ms.ToArray();
        }

        /// <summary>
        /// Lazily parses a sequence of STUN attributes (type, length, value, padding).
        /// Attribute types not handled here are skipped.
        /// </summary>
        /// <param name="bytes">The attribute section of a STUN message.</param>
        /// <param name="transactionId">The message's transaction id; needed to decode
        /// XOR-ed address attributes.</param>
        /// <returns>The recognized attributes, in wire order.</returns>
        internal static IEnumerable<Attribute> ParseAttributes(
            IEnumerable<byte> bytes,
            byte[] transactionId = null
        )
        {
            // Work on an array with an offset instead of chaining Skip() per attribute,
            // which re-walked the input from the start every iteration (O(n^2)).
            byte[] data = bytes as byte[] ?? bytes.ToArray();
            int offset = 0;

            // A trailing fragment shorter than an attribute header cannot hold an
            // attribute, so it is ignored rather than misread.
            while (data.Length - offset >= AttributeHeaderBytes)
            {
                ushort rawType = new System.ArraySegment<byte>(data, offset, 2).ToUShort();
                var type = (Attribute.AttributeType)rawType;
                ushort length = new System.ArraySegment<byte>(data, offset + 2, 2).ToUShort();

                // Copy the value, clamped to what is present if the attribute is truncated.
                int valueOffset = offset + AttributeHeaderBytes;
                var payload = new byte[System.Math.Min(length, data.Length - valueOffset)];
                System.Array.Copy(data, valueOffset, payload, 0, payload.Length);

                Attribute attr = type switch
                {
                    Attribute.AttributeType.ErrorCode => ErrorCode.Parse(payload),
                    Attribute.AttributeType.Realm => Realm.Parse(payload),
                    Attribute.AttributeType.Nonce => Stun.Attributes.Nonce.Parse(payload),
                    Attribute.AttributeType.Software => Software.Parse(payload),
                    Attribute.AttributeType.Fingerprint => Fingerprint.Parse(payload),
                    Attribute.AttributeType.XorMappedAddress =>
                        XorMappedAddress.Parse(payload, transactionId),
                    Attribute.AttributeType.XorRelayedAddress =>
                        XorRelayedAddress.Parse(payload, transactionId),
                    Attribute.AttributeType.ConnectionId => new ConnectionId(payload),
                    Attribute.AttributeType.Lifetime => new Lifetime((int)payload.ToUInt()),
                    _ => null,
                };

                if (!(attr is null))
                {
                    yield return attr;
                }

                // Values are padded to a 4-byte boundary and the padding is not included
                // in the length field (RFC 5389, section 15).
                int padding = (4 - (length % 4)) % 4;
                offset = valueOffset + length + padding;
            }
        }

        /// <summary>
        /// Extracts the message class from the two message-type bytes.
        /// </summary>
        /// <param name="high">The first (most significant) message-type byte.</param>
        /// <param name="low">The second message-type byte.</param>
        /// <returns>The <see cref="MessageClass"/> encoded in the C1/C0 bits.</returns>
        internal static MessageClass ParseClass(byte high, byte low)
        {
            int type = (high << 8) | low;

            // C1 is bit 8 and C0 is bit 4 (RFC 5389, section 6). Each must be masked on its
            // own; the neighbouring bits belong to the method and must not leak in.
            int c1 = (type >> 7) & 0x2;
            int c0 = (type >> 4) & 0x1;
            return (MessageClass)(c1 | c0);
        }

        /// <summary>
        /// Extracts the 12-bit method from the two message-type bytes.
        /// </summary>
        /// <param name="high">The first (most significant) message-type byte.</param>
        /// <param name="low">The second message-type byte.</param>
        /// <returns>The <see cref="MessageMethod"/> with the class bits removed.</returns>
        internal static MessageMethod ParseMethod(byte high, byte low)
        {
            int type = (high << 8) | low;

            // Drop the class bits (8 and 4) and close the gaps they leave behind.
            int m11To7 = (type & 0x3e00) >> 2;
            int m6To4 = (type & 0x00e0) >> 1;
            int m3To0 = type & 0x000f;
            return (MessageMethod)(m11To7 | m6To4 | m3To0);
        }

        /// <summary>
        /// Returns the first attribute of type <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">The attribute type to look for.</typeparam>
        /// <returns>The first matching attribute, or <c>null</c> if there is none.</returns>
        protected T GetAttribute<T>()
            where T : Attribute
        {
            return (Attributes ?? Enumerable.Empty<Attribute>()).OfType<T>().FirstOrDefault();
        }

        // Instantiates the concrete message type for a (class, method) pair, or returns
        // null when the pair is not supported.
        private static StunMessage CreateMessage(MessageClass @class, MessageMethod method) =>
            (@class, method) switch
            {
                (MessageClass.SuccessResponse, MessageMethod.Allocate) =>
                    new AllocateSuccessResponse(),
                (MessageClass.SuccessResponse, MessageMethod.Connect) =>
                    new ConnectSuccessResponse(),
                (MessageClass.SuccessResponse, MessageMethod.ConnectionBind) =>
                    new ConnectionBindSuccessResponse(),
                (MessageClass.SuccessResponse, MessageMethod.Binding) =>
                    new BindingSuccessResponse(),
                (MessageClass.SuccessResponse, MessageMethod.CreatePermission) =>
                    new CreatePermissionSuccessResponse(),
                (MessageClass.SuccessResponse, MessageMethod.Refresh) =>
                    new RefreshSuccessResponse(),
                (MessageClass.ErrorResponse, MessageMethod.Allocate) =>
                    new AllocateErrorResponse(),
                (MessageClass.ErrorResponse, MessageMethod.CreatePermission) =>
                    new CreatePermissionErrorResponse(),
                (MessageClass.ErrorResponse, MessageMethod.Refresh) =>
                    new RefreshErrorResponse(),
                (MessageClass.Indication, MessageMethod.ConnectionAttempt) =>
                    new ConnectionAttempt(),

                // The cast gives all arms the common type StunMessage.
                _ => (StunMessage)null,
            };

        // Reads exactly count bytes, looping over short reads (Stream.ReadAsync may
        // return fewer bytes than requested); throws if the stream ends early.
        private static async Task<byte[]> ReadBytesAsync(Stream stream, int count)
        {
            var buffer = new byte[count];
            int offset = 0;
            while (offset < count)
            {
                int read = await stream
                    .ReadAsync(buffer, offset, count - offset)
                    .ConfigureAwait(false);
                if (read == 0)
                {
                    throw new TurnClientException(
                        $"Unexpected end of stream: expected {count} bytes, got {offset}.");
                }

                offset += read;
            }

            return buffer;
        }
    }
307
}

Read our documentation on viewing source code .

Loading