spotify / featran
1
/*
2
 * Copyright 2017 Spotify AB.
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing,
11
 * software distributed under the License is distributed on an
12
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
 * KIND, either express or implied.  See the License for the
14
 * specific language governing permissions and limitations
15
 * under the License.
16
 */
17

18
package com.spotify.featran.numpy
19

20
import java.io.OutputStream
21

22
import simulacrum.typeclass
23
import scala.annotation.implicitNotFound
24

25
/**
 * Type class for NumPy numeric types.
 *
 * An instance supplies the dtype descriptor used in the `.npy` header, the encoded size of one
 * value, and the routine that writes a single value to an output stream.
 *
 * @tparam T element type; specialized for `Int`, `Long`, `Float` and `Double`
 */
@implicitNotFound("Could not find an instance of NumPyType for ${T}")
@typeclass trait NumPyType[@specialized(Int, Long, Float, Double) T] extends Serializable {

  /** NumPy dtype descriptor string placed in the `'descr'` field of the header, e.g. `"<i4"`. */
  def descr: String

  /**
   * Size of one encoded value. The instances in the companion object use the byte count
   * (4 for `Int`/`Float`, 8 for `Long`/`Double`).
   */
  def sizeOf: Int

  /**
   * Write the encoded form of a single value.
   *
   * @param out   stream receiving the bytes
   * @param value value to encode
   */
  def write(out: OutputStream, value: T): Unit
}
34

35
object NumPyType {

  /**
   * Little-endian writers for a plain `OutputStream`, adapted from Guava's
   * `LittleEndianDataOutputStream`.
   */
  implicit class LittleEndianOutputStream(private val out: OutputStream) extends AnyVal {

    /** Write `v` as 4 single-byte writes, least significant byte first. */
    def writeInt(v: Int): Unit = {
      var shift = 0
      while (shift < 32) {
        out.write(0xff & (v >> shift))
        shift += 8
      }
    }

    /** Write `v` as 8 bytes, least significant byte first, using one bulk `write` call. */
    def writeLong(v: Long): Unit = {
      val bytes = new Array[Byte](8)
      var idx = 0
      while (idx < bytes.length) {
        // byte `idx` of the output carries bits [8 * idx, 8 * idx + 7] of `v`
        bytes(idx) = (v >>> (8 * idx)).toByte
        idx += 1
      }
      out.write(bytes)
    }

    /** Write the IEEE 754 bit pattern of `v` (via `floatToIntBits`) in little-endian order. */
    def writeFloat(v: Float): Unit = {
      val bits = java.lang.Float.floatToIntBits(v)
      writeInt(bits)
    }

    /** Write the IEEE 754 bit pattern of `v` (via `doubleToLongBits`) in little-endian order. */
    def writeDouble(v: Double): Unit = {
      val bits = java.lang.Double.doubleToLongBits(v)
      writeLong(bits)
    }
  }

  /** `Int` values: descriptor `<i4`, 4 little-endian bytes each. */
  implicit val intNumPyType: NumPyType[Int] = new NumPyType[Int] {
    override val descr: String = "<i4"
    override val sizeOf: Int = 4
    override def write(out: OutputStream, value: Int): Unit =
      new LittleEndianOutputStream(out).writeInt(value)
  }

  /** `Long` values: descriptor `<i8`, 8 little-endian bytes each. */
  implicit val longNumPyType: NumPyType[Long] = new NumPyType[Long] {
    override val descr: String = "<i8"
    override val sizeOf: Int = 8
    override def write(out: OutputStream, value: Long): Unit =
      new LittleEndianOutputStream(out).writeLong(value)
  }

  /** `Float` values: descriptor `<f4`, 4 little-endian bytes each. */
  implicit val floatNumPyType: NumPyType[Float] = new NumPyType[Float] {
    override val descr: String = "<f4"
    override val sizeOf: Int = 4
    override def write(out: OutputStream, value: Float): Unit =
      new LittleEndianOutputStream(out).writeFloat(value)
  }

  /** `Double` values: descriptor `<f8`, 8 little-endian bytes each. */
  implicit val doubleNumPyType: NumPyType[Double] = new NumPyType[Double] {
    override val descr: String = "<f8"
    override val sizeOf: Int = 8
    override def write(out: OutputStream, value: Double): Unit =
      new LittleEndianOutputStream(out).writeDouble(value)
  }

  /* ======================================================================== */
  /* THE FOLLOWING CODE IS MANAGED BY SIMULACRUM; PLEASE DO NOT EDIT!!!!      */
  /* ======================================================================== */

  /** Summon an instance of [[NumPyType]] for `T`. */
  @inline def apply[T](implicit instance: NumPyType[T]): NumPyType[T] = instance

  object ops {
    implicit def toAllNumPyTypeOps[T](target: T)(implicit tc: NumPyType[T]): AllOps[T] {
      type TypeClassType = NumPyType[T]
    } = new AllOps[T] {
      type TypeClassType = NumPyType[T]
      val self: T = target
      val typeClassInstance: TypeClassType = tc
    }
  }
  trait Ops[@specialized(Int, Long, Float, Double) T] extends Serializable {
    type TypeClassType <: NumPyType[T]
    def self: T
    val typeClassInstance: TypeClassType
  }
  trait AllOps[@specialized(Int, Long, Float, Double) T] extends Ops[T]
  trait ToNumPyTypeOps extends Serializable {
    implicit def toNumPyTypeOps[T](target: T)(implicit tc: NumPyType[T]): Ops[T] {
      type TypeClassType = NumPyType[T]
    } = new Ops[T] {
      type TypeClassType = NumPyType[T]
      val self: T = target
      val typeClassInstance: TypeClassType = tc
    }
  }
  object nonInheritedOps extends ToNumPyTypeOps

  /* ======================================================================== */
  /* END OF SIMULACRUM-MANAGED CODE                                           */
  /* ======================================================================== */

}
128

129
/** Utilities for writing data as NumPy `.npy` files (format version 1.0). */
object NumPy {
  // Local import so the explicit charset is in scope without touching the file-level imports.
  import java.nio.charset.StandardCharsets

  /**
   * Build the header dictionary string for an array of element type `T`.
   *
   * The result is padded with spaces and terminated by `'\n'` so that the magic string, version
   * bytes, HEADER_LEN field and header together are a multiple of 16 bytes long.
   *
   * @param dimensions array shape; a single dimension is rendered as a one-element tuple `(n,)`
   * @return the padded, newline-terminated header text
   */
  private def header[T: NumPyType](dimensions: Seq[Int]): String = {
    // https://docs.scipy.org/doc/numpy/neps/npy-format.html
    val dims = dimensions.mkString(", ")
    val shape = if (dimensions.length > 1) s"($dims)" else s"($dims,)"
    val h =
      s"{'descr': '${NumPyType[T].descr}', 'fortran_order': False, 'shape': $shape, }"
    // 11 bytes: magic "0x93NUMPY", major version, minor version, (short) HEADER_LEN, '\n'
    val l = h.length + 11
    // pad magic string + 4 + HEADER_LEN to be evenly divisible by 16
    val n = (16 - l % 16) % 16
    h + " " * n + "\n"
  }

  /**
   * Write the `.npy` preamble: magic string, version 1.0, little-endian 16-bit HEADER_LEN and the
   * header text.
   *
   * All text is encoded as US-ASCII explicitly, so the bytes written and the HEADER_LEN value do
   * not depend on the JVM's default charset.
   *
   * @param out        stream receiving the preamble
   * @param dimensions array shape recorded in the header
   * @throws IllegalArgumentException if the header does not fit the 16-bit HEADER_LEN field
   */
  private def writeHeader[T: NumPyType](out: OutputStream, dimensions: Seq[Int]): Unit = {
    // Encode and validate before emitting anything, so a failure leaves the stream untouched.
    val headerBytes = header[T](dimensions).getBytes(StandardCharsets.US_ASCII)
    val l = headerBytes.length
    require(l <= 0xffff, s"Header too long for .npy format version 1.0: $l bytes")

    // magic
    out.write(0x93)
    out.write("NUMPY".getBytes(StandardCharsets.US_ASCII))

    // major, minor
    out.write(1)
    out.write(0)

    // header length, from Guava LittleEndianDataOutputStream#writeShort
    out.write(0xff & l)
    out.write(0xff & (l >> 8))
    out.write(headerBytes)
  }

  /** Write every element of `data`, in order, using the `NumPyType` instance for `T`. */
  private def writeData[T: NumPyType](out: OutputStream, data: Array[T]): Unit = {
    val tc = NumPyType[T]
    var i = 0
    while (i < data.length) {
      tc.write(out, data(i))
      i += 1
    }
  }

  /**
   * Write an array as a NumPy `.npy` file to an output stream.
   * Default shape is `(data.length,)`.
   *
   * The stream is flushed but not closed.
   *
   * @param out   stream receiving the file contents
   * @param data  elements, written in array order
   * @param shape dimensions recorded in the header; when non-empty, their product must equal
   *              `data.length`
   * @throws IllegalArgumentException if `shape` is non-empty and does not match `data.length`
   */
  def write[@specialized(Int, Long, Float, Double) T: NumPyType](
    out: OutputStream,
    data: Array[T],
    shape: Seq[Int] = Seq.empty
  ): Unit = {
    val dims = if (shape.isEmpty) {
      Seq(data.length)
    } else {
      // Accumulate as Long: an Int product can wrap around and accept a mismatching shape.
      require(
        shape.foldLeft(1L)(_ * _) == data.length.toLong,
        s"Invalid shape, ${shape.mkString(" * ")} != ${data.length}"
      )
      shape
    }
    writeHeader[T](out, dims)
    writeData[T](out, data)
    out.flush()
  }

  /**
   * Write an iterator of arrays as a 2-dimensional NumPy `.npy` file to an output stream. Each
   * array should have length `numCols` and the iterator should have `numRows` elements.
   *
   * The header is written before the rows are consumed, so the checks below fail only after
   * part of the output has already been written. The stream is flushed but not closed.
   *
   * @param out     stream receiving the file contents
   * @param data    rows, written in iteration order
   * @param numRows expected number of rows
   * @param numCols expected length of every row
   * @throws IllegalArgumentException if a row's length differs from `numCols`, or the iterator
   *                                  does not yield exactly `numRows` rows
   */
  def write[@specialized(Int, Long, Float, Double) T: NumPyType](
    out: OutputStream,
    data: Iterator[Array[T]],
    numRows: Int,
    numCols: Int
  ): Unit = {
    val dims = Seq(numRows, numCols)
    writeHeader[T](out, dims)
    var n = 0
    while (data.hasNext) {
      val row = data.next()
      require(row.length == numCols, s"Invalid row size, expected: $numCols, actual: ${row.length}")
      writeData[T](out, row)
      n += 1
    }
    require(n == numRows, s"Invalid number of rows, expected: $numRows, actual: $n")
    out.flush()
  }
}

Read our documentation on viewing source code.

Loading