1
/*
2
 * Copyright 2017 Spotify AB.
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing,
11
 * software distributed under the License is distributed on an
12
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
 * KIND, either express or implied.  See the License for the
14
 * specific language governing permissions and limitations
15
 * under the License.
16
 */
17

18
package com.spotify.featran
19

20
import _root_.java.util.regex.Pattern
21
import _root_.java.util.concurrent.ConcurrentHashMap
22
import _root_.java.util.function.Function
23

24
import com.spotify.featran.transformers.{MDLRecord, WeightedLabel}
25
import org.tensorflow.example.{Example, Features}
26
import org.tensorflow.{example => tf}
27
import shapeless.datatype.tensorflow.TensorFlowType
28

29
package object tensorflow {
30
  /**
   * Maps feature names onto the character set `[A-Za-z0-9_]` by replacing every other
   * character with `_`. Results are memoized because normalization runs on every feature write.
   */
  private[this] object FeatureNameNormalization {
    private[this] val NamePattern = Pattern.compile("[^A-Za-z0-9_]")

    // Raw name -> normalized name.
    // NOTE(review): unbounded; assumes the set of distinct feature names is small.
    private[this] val cache = new ConcurrentHashMap[String, String]()

    // Computes the normalized form of a name that is not cached yet.
    private[this] val computeNormalized = new Function[String, String] {
      override def apply(n: String): String =
        NamePattern.matcher(n).replaceAll("_")
    }

    /** Returns `name` with every character outside `[A-Za-z0-9_]` replaced by `_`. */
    val normalize: String => String = { name =>
      // Lock-free fast path: on Java 8, `computeIfAbsent` locks the bin even when the key is
      // already present (JDK-8161372), serializing concurrent lookups of the same hot names.
      val cached = cache.get(name)
      if (cached != null) cached else cache.computeIfAbsent(name, computeNormalized)
    }
  }
45

46
  /**
   * A TensorFlow feature paired with the name it is written under.
   *
   * @param name feature name as given by the caller (the flat writer normalizes it on output)
   * @param f    the encoded feature value
   */
  final case class NamedTFFeature(
    name: String,
    f: tf.Feature
  )
47

48
  /**
   * [[FeatureBuilder]] that accumulates named single-value float features into a TensorFlow
   * `Features` proto and emits them wrapped in an `Example`.
   *
   * The proto builder is `@transient`; when it comes back as `null` (e.g. after
   * deserialization) it is recreated in [[init]].
   *
   * @param underlying protobuf builder that features are written into
   */
  final case class TensorFlowFeatureBuilder(
    @transient private var underlying: tf.Features.Builder = tf.Features.newBuilder()
  ) extends FeatureBuilder[tf.Example] {

    /** Prepares for a new row by clearing all features; `dimension` is not used. */
    override def init(dimension: Int): Unit = {
      // The transient field is null whenever this instance was not built via its constructor.
      if (underlying eq null) underlying = tf.Features.newBuilder()
      underlying.clear()
    }

    /** Stores `value` as a one-element float list under the normalized form of `name`. */
    override def add(name: String, value: Double): Unit = {
      val floatList = tf.FloatList.newBuilder().addValue(value.toFloat)
      val feature = tf.Feature.newBuilder().setFloatList(floatList).build()
      underlying.putFeature(FeatureNameNormalization.normalize(name), feature)
    }

    /** Skipped features are simply left out, so nothing is recorded. */
    override def skip(): Unit = {}

    /** Skipped features are simply left out, so nothing is recorded. */
    override def skip(n: Int): Unit = {}

    /** Builds an `Example` from the features accumulated so far. */
    override def result: tf.Example = {
      val example = tf.Example.newBuilder()
      example.setFeatures(underlying)
      example.build()
    }

    /** Returns a fresh builder with its own, empty `Features` proto. */
    override def newBuilder: FeatureBuilder[tf.Example] = new TensorFlowFeatureBuilder()
  }
72

73
  /**
   * [[FeatureBuilder]] for output as TensorFlow `Example` type.
   *
   * Declared as a `def`, so every implicit resolution yields a new, independent builder.
   */
  implicit def tensorFlowFeatureBuilder: FeatureBuilder[tf.Example] =
    new TensorFlowFeatureBuilder()
75

76 1
  /**
   * [[FlatReader]] that extracts features from a TensorFlow `Example`.
   *
   * Composite features are read from suffixed entries: MDL records from `<name>_label` and
   * `<name>_value`, weighted labels from the parallel lists `<name>_key` and `<name>_value`.
   */
  implicit val exampleFlatReader: FlatReader[tf.Example] = new FlatReader[tf.Example] {
    import TensorFlowType._

    /**
     * Returns the feature stored under `name`, or `None` if there is none.
     *
     * The raw name is tried first. If it is absent, the normalized name is tried (every
     * character outside `[A-Za-z0-9_]` replaced by `_`). `exampleFlatWriter` normalizes names
     * before storing them, so without this fallback a name such as `"a-b"` would never
     * round-trip.
     */
    def toFeature(name: String, ex: tf.Example): Option[tf.Feature] = {
      val fm = ex.getFeatures.getFeatureMap
      def lookup(key: String): Option[tf.Feature] =
        if (fm.containsKey(key)) Some(fm.get(key)) else None

      lookup(name).orElse {
        val normalized = FeatureNameNormalization.normalize(name)
        // Avoid a redundant second lookup when normalization is the identity.
        if (normalized == name) None else lookup(normalized)
      }
    }

    /** Reads the first value of feature `name`, converted with `toDoubles`. */
    def readDouble(name: String): Example => Option[Double] =
      (ex: Example) => toFeature(name, ex).flatMap(v => toDoubles(v).headOption)

    /**
     * Reads an MDL record from the first string of `<name>_label` and the first double of
     * `<name>_value`. Returns `None` if either is missing or empty.
     */
    def readMdlRecord(name: String): Example => Option[MDLRecord[String]] =
      (ex: Example) => {
        for {
          labelFeature <- toFeature(name + "_label", ex)
          label <- toStrings(labelFeature).headOption
          valueFeature <- toFeature(name + "_value", ex)
          value <- toDoubles(valueFeature).headOption
        } yield MDLRecord(label, value)
      }

    /**
     * Reads weighted labels by pairing `<name>_key` and `<name>_value` element-wise.
     * Returns `None` if either feature is missing or no pair can be formed.
     */
    def readWeightedLabel(name: String): Example => Option[List[WeightedLabel]] =
      (ex: Example) => {
        // Keys and values are written as two parallel lists, so they must be zipped.
        // Iterating them in nested generators would pair every key with every value.
        val labels = for {
          keyFeature <- toFeature(name + "_key", ex).toList
          valueFeature <- toFeature(name + "_value", ex).toList
          (key, value) <- toStrings(keyFeature).zip(toDoubles(valueFeature))
        } yield WeightedLabel(key, value)
        if (labels.isEmpty) None else Some(labels)
      }

    /** Reads all values of feature `name`, converted with `toDoubles`. */
    def readDoubles(name: String): Example => Option[Seq[Double]] =
      (ex: Example) => toFeature(name, ex).map(v => toDoubles(v))

    /** Same as [[readDoubles]], materialized as an array. */
    def readDoubleArray(name: String): Example => Option[Array[Double]] =
      (ex: Example) => toFeature(name, ex).map(v => toDoubles(v).toArray)

    /** Reads the first value of feature `name`, converted with `toStrings`. */
    def readString(name: String): Example => Option[String] =
      (ex: Example) => toFeature(name, ex).flatMap(v => toStrings(v).headOption)

    /** Reads all values of feature `name`, converted with `toStrings`. */
    def readStrings(name: String): Example => Option[Seq[String]] =
      (ex: Example) => toFeature(name, ex).map(v => toStrings(v))
  }
124

125 1
  /**
   * [[FlatWriter]] that encodes values with shapeless-datatype's `TensorFlowType.fromDoubles` /
   * `fromStrings` and assembles the resulting named features into an `Example`.
   *
   * Composite values are split across suffixed features: MDL records into `<name>_label` and
   * `<name>_value`, weighted labels into the parallel lists `<n>_key` and `<n>_value`.
   * Feature names are normalized only when the final `Example` is assembled in [[writer]].
   */
  implicit val exampleFlatWriter: FlatWriter[Example] = new FlatWriter[tf.Example] {
    import TensorFlowType._
    type IF = List[NamedTFFeature]

    /** A present double becomes a one-element feature; an absent one writes nothing. */
    override def writeDouble(name: String): Option[Double] => List[NamedTFFeature] =
      (v: Option[Double]) => v.map(d => NamedTFFeature(name, fromDoubles(Seq(d)).build())).toList

    /** Writes the record's label and value as two single-element features. */
    override def writeMdlRecord(name: String): Option[MDLRecord[String]] => List[NamedTFFeature] =
      (v: Option[MDLRecord[String]]) =>
        v.fold(List.empty[NamedTFFeature]) { record =>
          List(
            NamedTFFeature(name + "_label", fromStrings(Seq(record.label.toString)).build()),
            NamedTFFeature(name + "_value", fromDoubles(Seq(record.value)).build())
          )
        }

    /** Writes all label names and all weights as two parallel list features. */
    override def writeWeightedLabel(n: String): Option[Seq[WeightedLabel]] => List[NamedTFFeature] =
      (v: Option[Seq[WeightedLabel]]) =>
        v.fold(List.empty[NamedTFFeature]) { labels =>
          List(
            NamedTFFeature(n + "_key", fromStrings(labels.map(_.name)).build()),
            NamedTFFeature(n + "_value", fromDoubles(labels.map(_.value)).build())
          )
        }

    /** Writes a sequence of doubles as one list feature. */
    override def writeDoubles(name: String): Option[Seq[Double]] => List[NamedTFFeature] =
      (v: Option[Seq[Double]]) => v.map(ds => NamedTFFeature(name, fromDoubles(ds).build())).toList

    /** Writes an array of doubles as one list feature. */
    override def writeDoubleArray(name: String): Option[Array[Double]] => List[NamedTFFeature] =
      (v: Option[Array[Double]]) =>
        v.map(arr => NamedTFFeature(name, fromDoubles(arr).build())).toList

    /** A present string becomes a one-element feature; an absent one writes nothing. */
    override def writeString(name: String): Option[String] => List[NamedTFFeature] =
      (v: Option[String]) => v.map(s => NamedTFFeature(name, fromStrings(Seq(s)).build())).toList

    /** Writes a sequence of strings as one list feature. */
    override def writeStrings(name: String): Option[Seq[String]] => List[NamedTFFeature] =
      (v: Option[Seq[String]]) => v.map(ss => NamedTFFeature(name, fromStrings(ss).build())).toList

    /**
     * Merges all named features into one `Example`, in order. Each name is normalized first;
     * a later feature overwrites an earlier one with the same normalized name.
     */
    override def writer: Seq[List[NamedTFFeature]] => Example =
      (fns: Seq[List[NamedTFFeature]]) => {
        val features = Features.newBuilder()
        for {
          group <- fns
          nf <- group
        } features.putFeature(FeatureNameNormalization.normalize(nf.name), nf.f)
        Example.newBuilder().setFeatures(features.build()).build()
      }
  }
187
}

Read our documentation on viewing source code .

Loading