1
package com.twitter.finagle
2

3
import com.twitter.conversions.StorageUnitOps._
4
import com.twitter.finagle.Mux.param.{CompressionPreferences, MaxFrameSize, OppTls}
5
import com.twitter.finagle.client._
6
import com.twitter.finagle.factory.TimeoutFactory
7
import com.twitter.finagle.naming.BindingFactory
8
import com.twitter.finagle.filter.{NackAdmissionFilter, PayloadSizeFilter}
9
import com.twitter.finagle.mux.Handshake.Headers
10
import com.twitter.finagle.mux.pushsession._
11
import com.twitter.finagle.mux.transport._
12
import com.twitter.finagle.mux.{
13
  ExportCompressionUsage,
14
  Handshake,
15
  OpportunisticTlsParams,
16
  Request,
17
  Response,
18
  WithCompressionPreferences
19
}
20
import com.twitter.finagle.netty4.pushsession.{Netty4PushListener, Netty4PushTransporter}
21
import com.twitter.finagle.netty4.ssl.server.Netty4ServerSslChannelInitializer
22
import com.twitter.finagle.netty4.ssl.client.Netty4ClientSslChannelInitializer
23
import com.twitter.finagle.param.{Label, ProtocolLibrary, Stats, Timer, WithDefaultLoadBalancer}
24
import com.twitter.finagle.pool.BalancingPool
25
import com.twitter.finagle.pushsession._
26
import com.twitter.finagle.server._
27
import com.twitter.finagle.tracing._
28
import com.twitter.finagle.transport.Transport.{ClientSsl, ServerSsl}
29
import com.twitter.finagle.transport.Transport
30
import com.twitter.io.{Buf, ByteReader}
31
import com.twitter.util.{Future, StorageUnit}
32
import io.netty.channel.{Channel, ChannelPipeline}
33
import java.net.SocketAddress
34
import java.util.concurrent.Executor
35
import scala.collection.mutable.ArrayBuffer
36

37
/**
 * A client and server for the mux protocol described in [[com.twitter.finagle.mux]].
 *
 * `Mux.client` and `Mux.server` return preconfigured [[Mux.Client]] and [[Mux.Server]]
 * instances; `newService`, `newClient` and `serve` are convenience entry points that
 * delegate to them.
 */
object Mux extends Client[mux.Request, mux.Response] with Server[mux.Request, mux.Response] {

  /**
   * The current version of the mux protocol.
   */
  val LatestVersion: Short = 0x0001

  /**
   * Mux-specific stack params.
   */
  object param {

    /**
     * A class eligible for configuring the maximum size of a mux frame.
     * Any message that is larger than this value is fragmented across multiple
     * transmissions. Clients and Servers can use this to set an upper bound
     * on the size of messages they are willing to receive. The value is exchanged
     * and applied during the mux handshake.
     *
     * @note the size must be positive and at most `Int.MaxValue` bytes, because it is
     *       encoded into the handshake headers as an `Int`. Violations raise an
     *       `AssertionError` (note that `assert` can be elided at compile time).
     */
    case class MaxFrameSize(size: StorageUnit) {
      assert(size.inBytes <= Int.MaxValue, s"$size is not <= Int.MaxValue bytes")
      assert(size.inBytes > 0, s"$size must be positive")

      def mk(): (MaxFrameSize, Stack.Param[MaxFrameSize]) =
        (this, MaxFrameSize.param)
    }
    object MaxFrameSize {
      // Defaults to Int.MaxValue bytes, the largest value the constructor admits.
      // The explicit type keeps the implicit resolvable regardless of definition order
      // and is required by Scala 3.
      implicit val param: Stack.Param[MaxFrameSize] =
        Stack.Param(MaxFrameSize(Int.MaxValue.bytes))
    }

    /**
     * A class eligible for configuring if a client's TLS mode is opportunistic.
     * If it's not None, then mux will negotiate with the supplied level whether
     * to use TLS or not before setting up TLS.
     *
     * If it's None, it will not attempt to negotiate whether to use TLS or not
     * with the remote peer, and if TLS is configured, it will use mux over TLS.
     *
     * @note opportunistic TLS is not mutually intelligible with simple mux
     *       over TLS
     */
    case class OppTls(level: Option[OpportunisticTls.Level]) {
      def mk(): (OppTls, Stack.Param[OppTls]) =
        (this, OppTls.param)
    }
    object OppTls {
      // Annotated as `Stack.Param[OppTls]` so the anonymous subclass's refinement type
      // does not leak into the public API.
      implicit val param: Stack.Param[OppTls] = new Stack.Param[OppTls] {
        val default: OppTls = OppTls(None)

        // override this to have a "cleaner" output in the registry
        override def show(value: OppTls): Seq[(String, () => String)] = {
          val levelStr = value match {
            case OppTls(Some(oppTls)) => oppTls.value.toString
            case OppTls(None) => "none"
          }
          Seq(("opportunisticTlsLevel", () => levelStr))
        }
      }

      /** Determine whether opportunistic TLS is configured to `Desired` or `Required`. */
      def enabled(params: Stack.Params): Boolean = params[OppTls].level match {
        case Some(OpportunisticTls.Desired | OpportunisticTls.Required) => true
        case _ => false
      }
    }

    /**
     * A class eligible for configuring how to enable TLS.
     *
     * Only for internal use and testing--not intended to be exposed for
     * configuration to end-users.
     */
    private[finagle] case class TurnOnTlsFn(fn: (Stack.Params, ChannelPipeline) => Unit)
    private[finagle] object TurnOnTlsFn {
      // The default function is a no-op: it leaves the pipeline untouched.
      implicit val param: Stack.Param[TurnOnTlsFn] =
        Stack.Param(TurnOnTlsFn((_: Stack.Params, _: ChannelPipeline) => ()))
    }

    /**
     * Returns `params` with `ClientSsl` cleared whenever an opportunistic TLS level is
     * configured (any `Some` level). This tells the Netty4PushTransporter not to turn on
     * TLS at connect time, so it can be turned on later.
     */
    private[finagle] def removeTlsIfOpportunisticClient(params: Stack.Params): Stack.Params = {
      params[param.OppTls].level match {
        case None => params
        case _ => params + Transport.ClientSsl(None)
      }
    }

    /**
     * Returns `params` with `ServerSsl` cleared whenever an opportunistic TLS level is
     * configured (any `Some` level). This tells the Netty4PushListener not to turn on
     * TLS at accept time, so it can be turned on later.
     */
    private[finagle] def removeTlsIfOpportunisticServer(params: Stack.Params): Stack.Params = {
      params[param.OppTls].level match {
        case None => params
        case _ => params + Transport.ServerSsl(None)
      }
    }

    /**
     * Builds the [[ServerPingManager]] for a server session from an executor and the
     * session's message writer.
     */
    private[finagle] case class PingManager(builder: (Executor, MessageWriter) => ServerPingManager)

    private[finagle] object PingManager {
      // The default ignores the executor and uses `ServerPingManager.default`.
      implicit val param: Stack.Param[PingManager] = Stack.Param(PingManager { (_, writer) =>
        ServerPingManager.default(writer)
      })
    }

    /**
     * A class eligible for configuring if the client or server is willing to compress or decompress
     * requests or responses.
     */
    case class CompressionPreferences(compressionPreferences: Compression.LocalPreferences) {
      def mk(): (CompressionPreferences, Stack.Param[CompressionPreferences]) =
        (this, CompressionPreferences.param)
    }
    object CompressionPreferences {
      implicit val param: Stack.Param[CompressionPreferences] =
        Stack.Param(CompressionPreferences(Compression.DefaultLocal))
    }
  }

  object Client {

    /** Prepends bound residual paths to outbound Mux requests's destinations. */
    private object MuxBindingFactory extends BindingFactory.Module[mux.Request, mux.Response] {
      protected[this] def boundPathFilter(residual: Path) =
        Filter.mk[mux.Request, mux.Response, mux.Request, mux.Response] { (req, service) =>
          service(mux.Request(residual ++ req.destination, req.contexts, req.body))
        }
    }

    /**
     * Installs a client SSL initializer, named "opportunisticSslInit", at the head of the
     * channel pipeline. This is wired in as the client's [[param.TurnOnTlsFn]].
     */
    private[finagle] val tlsEnable: (Stack.Params, ChannelPipeline) => Unit = (params, pipeline) =>
      pipeline.addFirst("opportunisticSslInit", new Netty4ClientSslChannelInitializer(params))

    /** Default client params: the StackClient defaults plus mux-specific settings. */
    private[finagle] val params: Stack.Params = StackClient.defaultParams +
      ProtocolLibrary("mux") +
      param.TurnOnTlsFn(tlsEnable)

    private val stack: Stack[ServiceFactory[mux.Request, mux.Response]] = StackClient.newStack
    // We use a singleton pool to manage a multiplexed session. Because it's mux'd, we
    // don't want arbitrary interrupts on individual dispatches to cancel outstanding service
    // acquisitions, so we disable `allowInterrupts`.
      .replace(
        StackClient.Role.pool,
        BalancingPool.module[mux.Request, mux.Response](allowInterrupts = false)
      )
      // As per the config above, we don't allow interrupts to propagate past the pool.
      // However, we need to provide a way to cancel service acquisitions which are taking
      // too long, so we "move" the [[TimeoutFactory]] below the pool.
      .remove(StackClient.Role.postNameResolutionTimeout)
      .insertAfter(
        StackClient.Role.pool,
        TimeoutFactory.module[mux.Request, mux.Response](Stack.Role("MuxSessionTimeout"))
      )
      .replace(BindingFactory.role, MuxBindingFactory)
      // Because the payload filter also traces the sizes, it's important that we do so
      // after the tracing context is initialized.
      .insertAfter(
        TraceInitializerFilter.role,
        PayloadSizeFilter.clientModule[mux.Request, mux.Response](_.body.length, _.body.length)
      )
      // Since NackAdmissionFilter should operate on all requests sent over
      // the wire including retries, it must be below `Retries`. Since it
      // aggregates the status of the entire cluster, it must be above
      // `LoadBalancerFactory` (not part of the endpoint stack).
      .insertBefore(
        StackClient.Role.prepFactory,
        NackAdmissionFilter.module[mux.Request, mux.Response]
      )
      .prepend(ExportCompressionUsage.module)

    /**
     * Returns the headers that a client sends to a server.
     *
     * The frame-size and opportunistic-TLS headers are always sent. The compression header
     * is appended only when `compressionPreferences` is not disabled.
     *
     * @param maxFrameSize the maximum mux fragment size the client is willing to
     * receive from a server. It is encoded as an `Int`, so it must not exceed
     * `Int.MaxValue` bytes (which [[param.MaxFrameSize]] enforces).
     * @param tlsLevel the opportunistic TLS level the client advertises.
     * @param compressionPreferences the client's local compression preferences.
     */
    private[finagle] def headers(
      maxFrameSize: StorageUnit,
      tlsLevel: OpportunisticTls.Level,
      compressionPreferences: Compression.LocalPreferences
    ): Handshake.Headers = {
      // Built immutably, mirroring `Server.headers`, so no mutable buffer can escape
      // to callers (`ArrayBuffer#toSeq` returns the buffer itself on Scala 2.12).
      val withoutCompression = Seq(
        MuxFramer.Header.KeyBuf -> MuxFramer.Header.encodeFrameSize(maxFrameSize.inBytes.toInt),
        OpportunisticTls.Header.KeyBuf -> tlsLevel.buf
      )

      if (compressionPreferences.isDisabled) {
        withoutCompression
      } else {
        withoutCompression :+ (CompressionNegotiation.ClientHeader.KeyBuf ->
          CompressionNegotiation.ClientHeader.encode(compressionPreferences))
      }
    }

    /**
     * Check the opportunistic TLS configuration to ensure it's in a consistent state.
     *
     * @throws IllegalStateException if opportunistic TLS is `Desired` or `Required`
     *         but no `ClientSsl` configuration is present.
     */
    private[finagle] def validateTlsParamConsistency(params: Stack.Params): Unit = {
      if (param.OppTls.enabled(params) && params[ClientSsl].sslClientConfiguration.isEmpty) {
        val level = params[param.OppTls].level
        throw new IllegalStateException(
          s"Client desired opportunistic TLS ($level) but ClientSsl param is empty."
        )
      }
    }
  }

  /**
   * A mux client built on the push-based session stack.
   *
   * Each connection yields a [[MuxClientNegotiatingSession]]. It must complete the mux
   * handshake (`negotiate()`) before it is exposed as a `Service`.
   */
  final case class Client(
    stack: Stack[ServiceFactory[mux.Request, mux.Response]] = Mux.Client.stack,
    params: Stack.Params = Mux.Client.params)
      extends PushStackClient[mux.Request, mux.Response, Client]
      with WithDefaultLoadBalancer[Client]
      with OpportunisticTlsParams[Client]
      with WithCompressionPreferences[Client] {

    private[this] val statsReceiver = params[Stats].statsReceiver
    private[this] val sessionStats = new SharedNegotiationStats(statsReceiver)
    // Session-level components report their stats under the "mux" scope.
    private[this] val sessionParams = params + Stats(statsReceiver.scope("mux"))

    protected type SessionT = MuxClientNegotiatingSession
    protected type In = ByteReader
    protected type Out = Buf

    protected def newSession(
      handle: PushChannelHandle[ByteReader, Buf]
    ): Future[MuxClientNegotiatingSession] = {
      val negotiator: Option[Headers] => Future[MuxClientSession] = { headers =>
        new Negotiation.Client(sessionParams, sessionStats).negotiateAsync(handle, headers)
      }
      val headers = Mux.Client.headers(
        params[MaxFrameSize].size,
        params[OppTls].level.getOrElse(OpportunisticTls.Off),
        params[CompressionPreferences].compressionPreferences
      )

      Future.value(
        new MuxClientNegotiatingSession(
          handle = handle,
          version = Mux.LatestVersion,
          negotiator = negotiator,
          headers = headers,
          name = params[Label].label,
          stats = sessionParams[Stats].statsReceiver
        )
      )
    }

    override def newClient(dest: Name, label0: String): ServiceFactory[Request, Response] = {
      // We want to fail fast if the client's TLS configuration is inconsistent
      Mux.Client.validateTlsParamConsistency(params)
      super.newClient(dest, label0)
    }

    protected def newPushTransporter(sa: SocketAddress): PushTransporter[ByteReader, Buf] = {
      // We use a custom Netty4PushTransporter to provide a handle to the
      // underlying Netty channel via MuxChannelHandle, giving us the ability to
      // add TLS support later in the lifecycle of the socket connection.
      new Netty4PushTransporter[ByteReader, Buf](
        transportInit = _ => (),
        protocolInit = PipelineInit,
        remoteAddress = sa,
        params = Mux.param.removeTlsIfOpportunisticClient(params)
      ) {
        override protected def initSession[T <: PushSession[ByteReader, Buf]](
          channel: Channel,
          protocolInit: (ChannelPipeline) => Unit,
          sessionBuilder: (PushChannelHandle[ByteReader, Buf]) => Future[T]
        ): Future[T] = {
          // With this builder we add support for opportunistic TLS via `MuxChannelHandle`
          // and the respective `Negotiation` types. Adding more proxy types will break this pathway.
          def wrappedBuilder(pushChannelHandle: PushChannelHandle[ByteReader, Buf]): Future[T] =
            sessionBuilder(new MuxChannelHandle(pushChannelHandle, channel, sessionParams))

          super.initSession(channel, protocolInit, wrappedBuilder)
        }
      }
    }

    protected def toService(
      session: MuxClientNegotiatingSession
    ): Future[Service[Request, Response]] =
      session.negotiate().flatMap(_.asService)

    protected def copy1(
      stack: Stack[ServiceFactory[Request, Response]],
      params: Stack.Params
    ): Client = copy(stack, params)
  }

  def client: Client = Client()

  def newService(dest: Name, label: String): Service[mux.Request, mux.Response] =
    client.newService(dest, label)

  def newClient(dest: Name, label: String): ServiceFactory[mux.Request, mux.Response] =
    client.newClient(dest, label)

  object Server {

    private[finagle] val stack: Stack[ServiceFactory[mux.Request, mux.Response]] =
      StackServer.newStack
      // We remove the trace init filter and don't replace it with anything because
      // the mux codec initializes tracing.
        .remove(TraceInitializerFilter.role)
        .prepend(ExportCompressionUsage.module)
        // Because tracing initialization happens in the mux codec, we know the service stack
        // is dispatched with proper tracing context, so the ordering of this filter isn't
        // relevant.
        .prepend(PayloadSizeFilter.serverModule(_.body.length, _.body.length))

    /**
     * Installs a server SSL initializer, named "opportunisticSslInit", at the head of the
     * channel pipeline. This is wired in as the server's [[param.TurnOnTlsFn]].
     */
    private[finagle] val tlsEnable: (Stack.Params, ChannelPipeline) => Unit = (params, pipeline) =>
      pipeline.addFirst("opportunisticSslInit", new Netty4ServerSslChannelInitializer(params))

    /** Default server params: the StackServer defaults plus mux-specific settings. */
    private[finagle] val params: Stack.Params = StackServer.defaultParams +
      ProtocolLibrary("mux") +
      param.TurnOnTlsFn(tlsEnable)

    /**
     * Builds the push session for one accepted connection from its session reference,
     * the server params, the shared negotiation stats, the mux channel handle and the
     * service to dispatch to.
     */
    type SessionF = (
      RefPushSession[ByteReader, Buf],
      Stack.Params,
      SharedNegotiationStats,
      MuxChannelHandle,
      Service[Request, Response]
    ) => PushSession[ByteReader, Buf]

    /**
     * Installs a [[MuxServerNegotiator]] on `ref` and returns `ref`.
     *
     * The local handshake headers are derived from the configured frame size, the
     * opportunistic TLS level (defaulting to `Off`) and the compression preferences.
     * Negotiation runs with stats scoped under "mux".
     */
    val defaultSessionFactory: SessionF = (
      ref: RefPushSession[ByteReader, Buf],
      params: Stack.Params,
      sharedStats: SharedNegotiationStats,
      handle: MuxChannelHandle,
      service: Service[Request, Response]
    ) => {
      val scopedStatsParams = params + Stats(params[Stats].statsReceiver.scope("mux"))
      MuxServerNegotiator.build(
        ref = ref,
        handle = handle,
        service = service,
        makeLocalHeaders = Mux.Server
          .headers(
            _: Headers,
            params[MaxFrameSize].size,
            params[OppTls].level.getOrElse(OpportunisticTls.Off),
            params[CompressionPreferences].compressionPreferences
          ),
        negotiate = (service, headers) =>
          new Negotiation.Server(scopedStatsParams, sharedStats, service)
            .negotiate(handle, headers),
        timer = params[Timer].timer
      )
      ref
    }

    /**
     * Returns the headers that a server sends to a client.
     *
     * The frame-size and opportunistic-TLS headers are always sent. The compression header
     * is appended only when negotiation against the client's advertised preferences
     * yields enabled formats.
     *
     * @param clientHeaders The headers received from the client. This is useful since
     * the headers the server responds with can be based on the clients. If the client
     * sent no compression header, its compression is treated as off.
     *
     * @param maxFrameSize the maximum mux fragment size the server is willing to
     * receive from a client. It is encoded as an `Int`, so it must not exceed
     * `Int.MaxValue` bytes (which [[param.MaxFrameSize]] enforces).
     * @param tlsLevel the opportunistic TLS level the server advertises.
     * @param compressionPreferences the server's local compression preferences.
     */
    private[finagle] def headers(
      clientHeaders: Handshake.Headers,
      maxFrameSize: StorageUnit,
      tlsLevel: OpportunisticTls.Level,
      compressionPreferences: Compression.LocalPreferences
    ): Handshake.Headers = {
      val clientCompressionPreferences = Handshake
        .valueOf(CompressionNegotiation.ClientHeader.KeyBuf, clientHeaders)
        .map(CompressionNegotiation.ClientHeader.decode(_))
        .getOrElse(Compression.PeerCompressionOff)
      val compressionFormats = CompressionNegotiation.negotiate(
        compressionPreferences,
        clientCompressionPreferences
      )

      val withoutCompression = Seq(
        MuxFramer.Header.KeyBuf -> MuxFramer.Header.encodeFrameSize(maxFrameSize.inBytes.toInt),
        OpportunisticTls.Header.KeyBuf -> tlsLevel.buf
      )

      if (compressionFormats.isDisabled) {
        withoutCompression
      } else {
        withoutCompression :+ (CompressionNegotiation.ServerHeader.KeyBuf ->
          CompressionNegotiation.ServerHeader.encode(compressionFormats))
      }
    }

    /**
     * Check the opportunistic TLS configuration to ensure it's in a consistent state.
     *
     * @throws IllegalStateException if opportunistic TLS is `Desired` or `Required`
     *         but no `ServerSsl` configuration is present.
     */
    private[finagle] def validateTlsParamConsistency(params: Stack.Params): Unit = {
      // A `Desired`/`Required` level may upgrade the connection to TLS, which needs a
      // ServerSsl configuration; fail fast here rather than during negotiation.
      if (param.OppTls.enabled(params) && params[ServerSsl].sslServerConfiguration.isEmpty) {
        val level = params[param.OppTls].level
        throw new IllegalStateException(
          s"Server desired opportunistic TLS ($level) but ServerSsl param is empty."
        )
      }
    }
  }

  /**
   * A mux server built on the push-based session stack.
   *
   * @param sessionFactory builds the push session for each accepted connection;
   *        defaults to [[Server.defaultSessionFactory]].
   */
  final case class Server(
    stack: Stack[ServiceFactory[mux.Request, mux.Response]] = Mux.Server.stack,
    params: Stack.Params = Mux.Server.params,
    sessionFactory: Server.SessionF = Server.defaultSessionFactory)
      extends PushStackServer[mux.Request, mux.Response, Server]
      with OpportunisticTlsParams[Server]
      with WithCompressionPreferences[Server] {

    protected type PipelineReq = ByteReader
    protected type PipelineRep = Buf

    private[this] val sessionStats = new SharedNegotiationStats(params[Stats].statsReceiver)

    protected def newListener(): PushListener[ByteReader, Buf] = {
      // Fail fast on an inconsistent opportunistic TLS configuration.
      Mux.Server.validateTlsParamConsistency(params)
      new Netty4PushListener[ByteReader, Buf](
        pipelineInit = PipelineInit,
        params = Mux.param.removeTlsIfOpportunisticServer(params),
        setupMarshalling = identity
      ) {
        override protected def initializePushChannelHandle(
          ch: Channel,
          sessionFactory: SessionFactory
        ): Unit = {
          val proxyFactory: SessionFactory = { handle =>
            // We need to proxy via the MuxChannelHandle to get a vector
            // into the netty pipeline for handling installing the TLS
            // components of the pipeline after the negotiation.
            sessionFactory(new MuxChannelHandle(handle, ch, params))
          }
          super.initializePushChannelHandle(ch, proxyFactory)
        }
      }
    }

    protected def newSession(
      handle: PushChannelHandle[ByteReader, Buf],
      service: Service[Request, Response]
    ): RefPushSession[ByteReader, Buf] = {
      // `newListener` wraps every handle in a MuxChannelHandle, so any other handle
      // type indicates a wiring error.
      handle match {
        case h: MuxChannelHandle =>
          val ref = new RefPushSession[ByteReader, Buf](h, SentinelSession[ByteReader, Buf](h))
          sessionFactory(ref, params, sessionStats, h, service)
          ref

        case other =>
          throw new IllegalStateException(
            s"Expected to find a `MuxChannelHandle` but found ${other.getClass.getSimpleName}"
          )
      }
    }

    protected def copy1(
      stack: Stack[ServiceFactory[Request, Response]],
      params: Stack.Params
    ): Server = copy(stack, params)
  }

  def server: Server = Server()

  def serve(
    addr: SocketAddress,
    service: ServiceFactory[mux.Request, mux.Response]
  ): ListeningServer = server.serve(addr, service)
}

Read our documentation on viewing source code .

Loading