1
/*
2
 * Copyright 2017 Spotify AB.
3
 *
4
 * Licensed under the Apache License, Version 2.0 (the "License");
5
 * you may not use this file except in compliance with the License.
6
 * You may obtain a copy of the License at
7
 *
8
 *     http://www.apache.org/licenses/LICENSE-2.0
9
 *
10
 * Unless required by applicable law or agreed to in writing,
11
 * software distributed under the License is distributed on an
12
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
 * KIND, either express or implied.  See the License for the
14
 * specific language governing permissions and limitations
15
 * under the License.
16
 */
17

18
package com.spotify.featran.numpy
19

20
import java.io.OutputStream
21

22
import simulacrum.typeclass
23

24
/**
 * Type class describing how values of a numeric type `T` are laid out in a NumPy `.npy` file.
 *
 * Instances for `Int`, `Long`, `Float` and `Double` are provided by the companion object.
 */
@typeclass trait NumPyType[@specialized(Int, Long, Float, Double) T] {

  /**
   * Appends the binary encoding of a single `value` to `out`.
   * The companion-object instances each emit exactly [[sizeOf]] bytes per call.
   */
  def write(out: OutputStream, value: T): Unit

  /** NumPy type descriptor placed in the header's `descr` entry, e.g. `<i4`. */
  def descr: String

  /** Number of bytes occupied by one encoded element. */
  def sizeOf: Int
}
32

33
object NumPyType {

  /**
   * Little-endian writers for primitive values, adapted from Guava's
   * `LittleEndianDataOutputStream`. Every instance below declares a `<` (little-endian)
   * descriptor, so values are emitted least significant byte first.
   */
  implicit class LittleEndianOutputStream(private val out: OutputStream) extends AnyVal {

    /** Emits the 4 bytes of `v`, lowest first, as four single-byte `write` calls. */
    def writeInt(v: Int): Unit = {
      var shift = 0
      while (shift < 32) {
        out.write((v >>> shift) & 0xff)
        shift += 8
      }
    }

    /** Emits the 8 bytes of `v`, lowest first, as a single `write(Array[Byte])` call. */
    def writeLong(v: Long): Unit = {
      val bytes = new Array[Byte](8)
      var idx = 0
      while (idx < bytes.length) {
        // toByte keeps only the low 8 bits of the shifted value
        bytes(idx) = (v >>> (idx * 8)).toByte
        idx += 1
      }
      out.write(bytes)
    }

    /** Emits the IEEE 754 single-precision bits of `v` (via `floatToIntBits`, NaNs canonicalized). */
    def writeFloat(v: Float): Unit = {
      val bits = java.lang.Float.floatToIntBits(v)
      writeInt(bits)
    }

    /** Emits the IEEE 754 double-precision bits of `v` (via `doubleToLongBits`, NaNs canonicalized). */
    def writeDouble(v: Double): Unit = {
      val bits = java.lang.Double.doubleToLongBits(v)
      writeLong(bits)
    }
  }

  /** 32-bit signed integers, NumPy descriptor `<i4`. */
  implicit val intNumPyType: NumPyType[Int] = new NumPyType[Int] {
    override def write(out: OutputStream, value: Int): Unit = out.writeInt(value)
    override val descr: String = "<i4"
    override val sizeOf: Int = 4
  }

  /** 64-bit signed integers, NumPy descriptor `<i8`. */
  implicit val longNumPyType: NumPyType[Long] = new NumPyType[Long] {
    override def write(out: OutputStream, value: Long): Unit = out.writeLong(value)
    override val descr: String = "<i8"
    override val sizeOf: Int = 8
  }

  /** 32-bit IEEE 754 floats, NumPy descriptor `<f4`. */
  implicit val floatNumPyType: NumPyType[Float] = new NumPyType[Float] {
    override def write(out: OutputStream, value: Float): Unit = out.writeFloat(value)
    override val descr: String = "<f4"
    override val sizeOf: Int = 4
  }

  /** 64-bit IEEE 754 floats, NumPy descriptor `<f8`. */
  implicit val doubleNumPyType: NumPyType[Double] = new NumPyType[Double] {
    override def write(out: OutputStream, value: Double): Unit = out.writeDouble(value)
    override val descr: String = "<f8"
    override val sizeOf: Int = 8
  }
}
88

89
/** Utilities for writing data as NumPy `.npy` files (format version 1.0). */
object NumPy {

  // The header is encoded one byte per char, so its encoded size always equals the char
  // count used for padding and HEADER_LEN. A platform-default charset could differ.
  private val HeaderCharset = java.nio.charset.StandardCharsets.US_ASCII

  /**
   * Builds the header dictionary for elements of type `T` with the given shape, padded with
   * spaces and terminated by a newline.
   */
  private def header[T: NumPyType](dimensions: Seq[Int]): String = {
    // https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html
    val dims = dimensions.mkString(", ")
    // A one-element Python tuple needs a trailing comma, e.g. `(5,)`.
    val shape = if (dimensions.length > 1) s"($dims)" else s"($dims,)"
    val h =
      s"{'descr': '${NumPyType[T].descr}', 'fortran_order': False, 'shape': $shape, }"
    // 11 bytes: magic 0x93 + "NUMPY" (6), major version (1), minor version (1),
    // (short) HEADER_LEN (2), '\n' (1)
    val l = h.length + 11
    // pad magic string + 4 + HEADER_LEN to be evenly divisible by 16
    val n = (16 - l % 16) % 16
    h + " " * n + "\n"
  }

  /**
   * Writes the magic string, version 1.0, HEADER_LEN and header text.
   *
   * @throws IllegalArgumentException if the header does not fit in an unsigned short length
   */
  private def writeHeader[T: NumPyType](out: OutputStream, dimensions: Seq[Int]): Unit = {
    // Build and validate the header before emitting anything so a failure leaves `out` untouched.
    val headerBytes = header[T](dimensions).getBytes(HeaderCharset)
    val l = headerBytes.length
    // version 1.0 stores HEADER_LEN as an unsigned little-endian short
    require(l <= 0xffff, s"Header too long for NumPy format 1.0: $l bytes")

    // magic
    out.write(0x93)
    out.write("NUMPY".getBytes(HeaderCharset))

    // major, minor
    out.write(1)
    out.write(0)

    // header length, from Guava LittleEndianDataOutputStream#writeShort
    out.write(0xff & l)
    out.write(0xff & (l >> 8))
    out.write(headerBytes)
  }

  /** Writes every element of `data`, in order, using the `NumPyType[T]` encoding. */
  private def writeData[T: NumPyType](out: OutputStream, data: Array[T]): Unit = {
    val tpe = NumPyType[T]
    var i = 0
    while (i < data.length) {
      tpe.write(out, data(i))
      i += 1
    }
  }

  /**
   * Write an array as a NumPy `.npy` file to an output stream.
   * Default shape is `(data.length,)`.
   *
   * @param out   destination stream; flushed but not closed
   * @param data  elements in C (row-major) order
   * @param shape array dimensions; their product must equal `data.length`
   * @throws IllegalArgumentException if `shape` has a negative dimension or its product does
   *                                  not match `data.length`
   */
  def write[@specialized(Int, Long, Float, Double) T: NumPyType](
    out: OutputStream,
    data: Array[T],
    shape: Seq[Int] = Seq.empty
  ): Unit = {
    val dims = if (shape.isEmpty) {
      Seq(data.length)
    } else {
      require(
        shape.forall(_ >= 0),
        s"Invalid shape, negative dimension in ${shape.mkString(" * ")}"
      )
      // Multiply exactly: an Int product can overflow and spuriously match data.length.
      require(
        shape.map(BigInt(_)).product == BigInt(data.length),
        s"Invalid shape, ${shape.mkString(" * ")} != ${data.length}"
      )
      shape
    }
    writeHeader[T](out, dims)
    writeData(out, data)
    out.flush()
  }

  /**
   * Write an iterator of arrays as a 2-dimensional NumPy `.npy` file to an output stream. Each
   * array should have length `numCols` and the iterator should have `numRows` elements.
   *
   * The header is written before rows are validated. If a row has the wrong length or the row
   * count does not match, an exception is thrown after some data was already written to `out`.
   *
   * @throws IllegalArgumentException if `numRows`/`numCols` is negative, a row's length differs
   *                                  from `numCols`, or the iterator does not yield `numRows` rows
   */
  def write[@specialized(Int, Long, Float, Double) T: NumPyType](
    out: OutputStream,
    data: Iterator[Array[T]],
    numRows: Int,
    numCols: Int
  ): Unit = {
    require(
      numRows >= 0 && numCols >= 0,
      s"Invalid dimensions, numRows: $numRows, numCols: $numCols"
    )
    val dims = Seq(numRows, numCols)
    writeHeader[T](out, dims)
    var n = 0
    while (data.hasNext) {
      val row = data.next()
      require(row.length == numCols, s"Invalid row size, expected: $numCols, actual: ${row.length}")
      writeData(out, row)
      n += 1
    }
    require(n == numRows, s"Invalid number of rows, expected: $numRows, actual: $n")
    out.flush()
  }
}

Read our documentation on viewing source code .

Loading