twitter / chill
1
/*
Copyright 2013 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
 */
16

17
package com.twitter.chill.scrooge
18

19
import com.esotericsoftware.kryo.Kryo
20
import com.esotericsoftware.kryo.Serializer
21
import com.esotericsoftware.kryo.io.Input
22
import com.esotericsoftware.kryo.io.Output
23
import com.twitter.scrooge.{ThriftStruct, ThriftStructCodec, ThriftStructSerializer}
24
import org.apache.thrift.protocol.TBinaryProtocol
25
import scala.collection.mutable
26
import scala.util.Try
27

28
/**
 * Lookup and caching of Scrooge `ThriftStructSerializer`s, keyed by ThriftStruct class.
 *
 * This object is a JVM-wide singleton, so its cache is shared by every Kryo instance that uses
 * [[ScroogeThriftStructSerializer]]. Individual Kryo instances are not thread safe, but several
 * of them may run on different threads at once, so the cache is backed by a concurrent map.
 */
object ScroogeThriftStructSerializer {
  /* Cache from a ThriftStruct class to the serializer built for it.
   * Marked @transient so it is not serialized: the cached values are anonymous
   * ThriftStructSerializers that reference this object, which in turn references the cache.
   * The static type stays mutable.Map; the TrieMap backing makes concurrent access safe.
   * Concurrent misses for the same class may each build a serializer; they are interchangeable.
   */
  @transient lazy private[this] val classToTSS: mutable.Map[Class[_], ThriftStructSerializer[_]] =
    scala.collection.concurrent.TrieMap.empty[Class[_], ThriftStructSerializer[_]]

  /** Returns the singleton instance of a Scala `object`, given its module class (`Foo$`). */
  private def getObject[T](companionClass: Class[T]): AnyRef =
    companionClass.getField("MODULE$").get(null)

  /**
   * Resolves the codec for a nested class whose codec is the enclosing companion, e.g. a
   * union member `com.foo.MyUnion$Member` whose codec is `com.foo.MyUnion$`. Everything after the
   * last `$` is dropped while the `$` itself is kept. If the name has no `$`, the lookup is for
   * the empty class name and fails.
   * This is only done once per class, because results are cached by `lookupThriftStructSerializer`.
   */
  private[this] def codecForUnion[T <: ThriftStruct](maybeUnion: Class[T]): Try[ThriftStructCodec[T]] =
    Try {
      val name = maybeUnion.getName
      val companionName = name.substring(0, name.lastIndexOf('$') + 1)
      getObject(Class.forName(companionName, true, maybeUnion.getClassLoader))
    }.map(_.asInstanceOf[ThriftStructCodec[T]])

  /** Resolves the codec as the class's own companion object (`<ClassName>$`). */
  private[this] def codecForNormal[T <: ThriftStruct](
      thriftStructClass: Class[T]
  ): Try[ThriftStructCodec[T]] =
    Try(getObject(Class.forName(thriftStructClass.getName + "$", true, thriftStructClass.getClassLoader)))
      .map(_.asInstanceOf[ThriftStructCodec[T]])

  /**
   * Finds the codec for a Scrooge-generated struct. It first tries the class's own companion
   * (the usual case), then the enclosing companion (the union-member case).
   * If both fail, the union failure is thrown with the normal failure attached as suppressed,
   * so the more informative first error is not lost.
   */
  private[this] def constructCodec[T <: ThriftStruct](thriftStructClass: Class[T]): ThriftStructCodec[T] =
    codecForNormal(thriftStructClass)
      .recoverWith { case normalFailure =>
        codecForUnion(thriftStructClass).recoverWith { case unionFailure =>
          unionFailure.addSuppressed(normalFailure)
          scala.util.Failure(unionFailure)
        }
      }
      .get

  /** Builds a binary-protocol ThriftStructSerializer around the codec resolved for the class. */
  private[this] def constructThriftStructSerializer[T <: ThriftStruct](
      thriftStructClass: Class[T]
  ): ThriftStructSerializer[T] = {
    // Resolve the codec eagerly so that lookup failures surface here, not on first use.
    val newCodec = constructCodec(thriftStructClass)
    new ThriftStructSerializer[T] {
      val protocolFactory = new TBinaryProtocol.Factory
      override def codec: ThriftStructCodec[T] = newCodec
    }
  }

  /**
   * Returns the cached serializer for `thriftStructClass`, building and caching it on first use.
   * Throws if no codec can be found for the class.
   */
  def lookupThriftStructSerializer[T <: ThriftStruct](
      thriftStructClass: Class[_ <: T]
  ): ThriftStructSerializer[T] = {
    val tss =
      classToTSS.getOrElseUpdate(thriftStructClass, constructThriftStructSerializer(thriftStructClass))
    tss.asInstanceOf[ThriftStructSerializer[T]]
  }

  /** Returns the cached serializer for the runtime class of `thriftStruct`. */
  def lookupThriftStructSerializer[T <: ThriftStruct](thriftStruct: T): ThriftStructSerializer[T] =
    lookupThriftStructSerializer(thriftStruct.getClass)
}
88

89
/**
 * Kryo serializer for Scrooge-generated Thrift structs.
 *
 * Each record is the byte length, written as a Kryo variable-length int optimized for positive
 * values, followed by the bytes produced by the struct's cached `ThriftStructSerializer`.
 * Any `Exception` is rethrown wrapped in a `RuntimeException` that names the struct's class.
 */
class ScroogeThriftStructSerializer[T <: ThriftStruct] extends Serializer[T] {
  import ScroogeThriftStructSerializer._

  override def write(kryo: Kryo, output: Output, thriftStruct: T): Unit =
    try {
      val bytes = lookupThriftStructSerializer(thriftStruct).toBytes(thriftStruct)
      output.writeInt(bytes.length, true)
      output.writeBytes(bytes)
    } catch {
      case e: Exception =>
        throw new RuntimeException(s"Could not serialize ThriftStruct of type ${thriftStruct.getClass}", e)
    }

  /* The Kryo API types `thriftStructClass` as Class[T], but the class handed in is really the
   * concrete runtime class of the struct, i.e. a Class[_ <: T].
   */
  override def read(kryo: Kryo, input: Input, thriftStructClass: Class[T]): T =
    try {
      val tss = lookupThriftStructSerializer(thriftStructClass)
      val payload = new Array[Byte](input.readInt(true))
      input.readBytes(payload)
      tss.fromBytes(payload)
    } catch {
      case e: Exception => throw new RuntimeException(s"Could not create ThriftStruct $thriftStructClass", e)
    }
}

Read our documentation on viewing source code .

Loading