/*
 * Copyright 2017 Spotify AB.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
17

18
package com.spotify.featran.java
19

20
import java.lang.{Double => JDouble}
21
import java.util.function.BiFunction
22
import java.util.{Collections, Optional, List => JList}
23

24
import com.spotify.featran._
25
import com.spotify.featran.tensorflow._
26
import com.spotify.featran.xgboost._
27
import ml.dmlc.xgboost4j.LabeledPoint
28
import org.tensorflow.example.Example
29

30
import scala.collection.JavaConverters._
31
import scala.reflect.ClassTag
32

33
/**
 * Scala-side adapters for the Java API in `com.spotify.featran.java`.
 *
 * The object is package-private, so its members are only reachable from this package. Each member
 * does one of two things:
 *  - converts Java functional or collection types into Scala ones; or
 *  - delegates to a generic Scala call, using its declared return type to pin the output type.
 */
private object JavaOps {

  //================================================================================
  // Function adapters
  //================================================================================

  /** Wraps `f` as a Scala function; every call is forwarded to `f.apply`. */
  def requiredFn[I, O](f: SerializableFunction[I, O]): I => O =
    arg => f.apply(arg)

  /**
   * Wraps an `Optional`-returning `f` as a Scala function returning `Option`. An empty result
   * becomes `None`; a present one becomes `Some` of its value.
   */
  def optionalFn[I, O](f: SerializableFunction[I, Optional[O]]): I => Option[O] =
    arg => {
      val result = f.apply(arg)
      if (!result.isPresent) None else Some(result.get())
    }

  /**
   * Wraps a boxed-`Double` `BiFunction` as a Scala function on primitive doubles.
   *
   * Arguments are boxed and the result unboxed through Predef's implicit conversions. A `null`
   * returned by `f` therefore fails when it is unboxed.
   */
  def crossFn(f: BiFunction[JDouble, JDouble, JDouble]): (Double, Double) => Double =
    (x, y) => f.apply(x, y)

  //================================================================================
  // CollectionType for java.util.List
  //================================================================================

  /**
   * [[CollectionType]] instance for `java.util.List`.
   *
   * It is picked up implicitly by the `JList`-based wrappers below (presumably as the collection
   * evidence required by `FeatureSpec.extract*`). `reduce` and `pure` always produce
   * single-element lists.
   */
  implicit val jListCollectionType: CollectionType[JList] = new CollectionType[JList] {
    override def map[A, B: ClassTag](ma: JList[A])(f: A => B): JList[B] = {
      val mapped = ma.asScala.map(f)
      mapped.asJava
    }

    // Folds all elements with `f`. Like Scala's `reduce`, this throws on an empty list.
    override def reduce[A](ma: JList[A])(f: (A, A) => A): JList[A] = {
      val combined = ma.asScala.reduce(f)
      Collections.singletonList(combined)
    }

    // Pairs each element of `ma` with the first element of `mb`; any further elements of `mb` are
    // ignored. `mb.get(0)` is only evaluated when `ma` is non-empty, so it throws only if `mb` is
    // empty while `ma` is not.
    override def cross[A, B: ClassTag](ma: JList[A])(mb: JList[B]): JList[(A, B)] =
      ma.asScala.map(a => (a, mb.get(0))).asJava

    // `ma` is ignored; the result is always the single element `b`.
    override def pure[A, B: ClassTag](ma: JList[A])(b: B): JList[B] =
      Collections.singletonList(b)
  }

  //================================================================================
  // Record extractors from subset settings
  //================================================================================
  // Each wrapper delegates to `FeatureSpec.extractWithSubsetSettings`. The declared return type
  // fixes the output feature type, and with it the implicit `FeatureBuilder` resolved here.

  def extractWithSubsetSettingsFloat[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Array[Float]] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsDouble[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Array[Double]] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsFloatSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, FloatSparseArray] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsDoubleSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, DoubleSparseArray] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsDoubleNamedSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, DoubleNamedSparseArray] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsFloatNamedSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, FloatNamedSparseArray] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsExample[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Example] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsLabeledPoint[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, LabeledPoint] =
    fs.extractWithSubsetSettings(settings)

  def extractWithSubsetSettingsSparseLabeledPoint[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, SparseLabeledPoint] =
    fs.extractWithSubsetSettings(settings)

  //================================================================================
  // Wrappers for FeatureSpec
  //================================================================================

  /** Delegates to `FeatureSpec.extract` over a Java list of inputs. */
  def extract[T](fs: FeatureSpec[T], input: JList[T]): FeatureExtractor[JList, T] =
    fs.extract(input)

  /**
   * Delegates to `FeatureSpec.extractWithSettings` over a Java list of inputs. The `settings`
   * string is passed as a single-element list.
   */
  def extractWithSettings[T](
    fs: FeatureSpec[T],
    input: JList[T],
    settings: String
  ): FeatureExtractor[JList, T] =
    fs.extractWithSettings(input, Collections.singletonList(settings))

  // Record-extractor variants of `extractWithSettings`. As above, the declared return type selects
  // the output feature type.

  def extractWithSettingsFloat[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Array[Float]] =
    fs.extractWithSettings(settings)

  def extractWithSettingsDouble[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Array[Double]] =
    fs.extractWithSettings(settings)

  def extractWithSettingsFloatSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, FloatSparseArray] =
    fs.extractWithSettings(settings)

  def extractWithSettingsDoubleSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, DoubleSparseArray] =
    fs.extractWithSettings(settings)

  def extractWithSettingsFloatNamedSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, FloatNamedSparseArray] =
    fs.extractWithSettings(settings)

  def extractWithSettingsDoubleNamedSparseArray[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, DoubleNamedSparseArray] =
    fs.extractWithSettings(settings)

  def extractWithSettingsExample[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, Example] =
    fs.extractWithSettings(settings)

  def extractWithSettingsLabeledPoint[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, LabeledPoint] =
    fs.extractWithSettings(settings)

  def extractWithSettingsSparseLabeledPoint[T](
    fs: FeatureSpec[T],
    settings: String
  ): RecordExtractor[T, SparseLabeledPoint] =
    fs.extractWithSettings(settings)

  //================================================================================
  // Wrappers for FeatureExtractor
  //================================================================================

  /** Returns the first entry of the extractor's feature settings list. */
  def featureSettings[T](fe: FeatureExtractor[JList, T]): String =
    fe.featureSettings.get(0)

  /** Returns the first entry of the extractor's feature names, converted to a Java list. */
  def featureNames[T](fe: FeatureExtractor[JList, T]): JList[String] =
    fe.featureNames.get(0).asJava

  def featureValuesFloat[T](fe: FeatureExtractor[JList, T]): JList[Array[Float]] =
    fe.featureValues[Array[Float]]

  def featureValuesDouble[T](fe: FeatureExtractor[JList, T]): JList[Array[Double]] =
    fe.featureValues[Array[Double]]

  // Builders for the Java-facing sparse array classes. Each one maps the generic
  // SparseArray / NamedSparseArray builder into the matching subclass. They are declared as `def`,
  // so the expression is re-evaluated at every use.

  implicit def floatSparseArrayFB: FeatureBuilder[FloatSparseArray] =
    FeatureBuilder[SparseArray[Float]]
      .map(sa => new FloatSparseArray(sa.indices, sa.values, sa.length))

  implicit def doubleSparseArrayFB: FeatureBuilder[DoubleSparseArray] =
    FeatureBuilder[SparseArray[Double]]
      .map(sa => new DoubleSparseArray(sa.indices, sa.values, sa.length))

  def featureValuesFloatSparseArray[T](fe: FeatureExtractor[JList, T]): JList[FloatSparseArray] =
    fe.featureValues[FloatSparseArray]

  def featureValuesDoubleSparseArray[T](fe: FeatureExtractor[JList, T]): JList[DoubleSparseArray] =
    fe.featureValues[DoubleSparseArray]

  implicit def floatNamedSparseArrayFB: FeatureBuilder[FloatNamedSparseArray] =
    FeatureBuilder[NamedSparseArray[Float]]
      .map(nsa => new FloatNamedSparseArray(nsa.indices, nsa.values, nsa.length, nsa.names))

  implicit def doubleNamedSparseArrayFB: FeatureBuilder[DoubleNamedSparseArray] =
    FeatureBuilder[NamedSparseArray[Double]]
      .map(nsa => new DoubleNamedSparseArray(nsa.indices, nsa.values, nsa.length, nsa.names))

  def featureValuesFloatNamedSparseArray[T](
    fe: FeatureExtractor[JList, T]
  ): JList[FloatNamedSparseArray] =
    fe.featureValues[FloatNamedSparseArray]

  def featureValuesDoubleNamedSparseArray[T](
    fe: FeatureExtractor[JList, T]
  ): JList[DoubleNamedSparseArray] =
    fe.featureValues[DoubleNamedSparseArray]

  def featureValuesExample[T](fe: FeatureExtractor[JList, T]): JList[Example] =
    fe.featureValues[Example]

  def featureValuesLabeledPoint[T](fe: FeatureExtractor[JList, T]): JList[LabeledPoint] =
    fe.featureValues[LabeledPoint]

  def featureValuesSparseLabeledPoint[T](
    fe: FeatureExtractor[JList, T]
  ): JList[SparseLabeledPoint] =
    fe.featureValues[SparseLabeledPoint]

  //================================================================================
  // Wrappers for RecordExtractor
  //================================================================================

  /** Returns the record extractor's feature names as a Java list. */
  def featureNames[F, T](fe: RecordExtractor[T, F]): JList[String] =
    fe.featureNames.asJava
}

/**
 * Float-valued [[SparseArray]] exposed to Java callers.
 *
 * The constructor is `private[java]`, so instances can only be created inside the
 * `com.spotify.featran.java` package.
 */
class FloatSparseArray private[java] (
  indices: Array[Int],
  override val values: Array[Float],
  length: Int
) extends SparseArray[Float](indices, values, length) {

  /**
   * Returns the dense form produced by the inherited `SparseArray.toDense`.
   *
   * This overload takes no parameters. That is presumably so Java code need not supply the
   * implicit evidence the inherited `toDense` expects — confirm against [[SparseArray]].
   */
  def toDense: Array[Float] = {
    val dense: Array[Float] = super.toDense
    dense
  }
}

/**
 * Double-valued [[SparseArray]] exposed to Java callers.
 *
 * The constructor is `private[java]`, so instances can only be created inside the
 * `com.spotify.featran.java` package.
 */
class DoubleSparseArray private[java] (
  indices: Array[Int],
  override val values: Array[Double],
  length: Int
) extends SparseArray[Double](indices, values, length) {

  /**
   * Returns the dense form produced by the inherited `SparseArray.toDense`.
   *
   * This overload takes no parameters. That is presumably so Java code need not supply the
   * implicit evidence the inherited `toDense` expects — confirm against [[SparseArray]].
   */
  def toDense: Array[Double] = {
    val dense: Array[Double] = super.toDense
    dense
  }
}

/**
 * Float-valued [[NamedSparseArray]] exposed to Java callers; `names` is forwarded to the parent.
 *
 * The constructor is `private[java]`, so instances can only be created inside the
 * `com.spotify.featran.java` package.
 */
class FloatNamedSparseArray private[java] (
  indices: Array[Int],
  override val values: Array[Float],
  length: Int,
  names: Seq[String]
) extends NamedSparseArray[Float](indices, values, length, names) {

  /**
   * Returns the dense form produced by the inherited `NamedSparseArray.toDense`.
   *
   * This overload takes no parameters. That is presumably so Java code need not supply the
   * implicit evidence the inherited `toDense` expects — confirm against [[NamedSparseArray]].
   */
  def toDense: Array[Float] = {
    val dense: Array[Float] = super.toDense
    dense
  }
}

/**
 * Double-valued [[NamedSparseArray]] exposed to Java callers; `names` is forwarded to the parent.
 *
 * The constructor is `private[java]`, so instances can only be created inside the
 * `com.spotify.featran.java` package.
 */
class DoubleNamedSparseArray private[java] (
  indices: Array[Int],
  override val values: Array[Double],
  length: Int,
  names: Seq[String]
) extends NamedSparseArray[Double](indices, values, length, names) {

  /**
   * Returns the dense form produced by the inherited `NamedSparseArray.toDense`.
   *
   * This overload takes no parameters. That is presumably so Java code need not supply the
   * implicit evidence the inherited `toDense` expects — confirm against [[NamedSparseArray]].
   */
  def toDense: Array[Double] = {
    val dense: Array[Double] = super.toDense
    dense
  }
}

Read our documentation on viewing source code .

Loading