arviz-devs / ArviZ.jl
1
"""
2
    Dataset(::PyObject)
3
    Dataset(; data_vars = nothing, coords = nothing, attrs = nothing)
4

5
Loose wrapper around `xarray.Dataset`, mostly used for dispatch.
6

7
# Keywords
8

9
  - `data_vars::Dict{String,Any}`: Dict mapping variable names to
10
    
11
      + `Vector`: Data vector. Single dimension is named after variable.
12
      + `Tuple{String,Vector}`: Dimension name and data vector.
13
      + `Tuple{NTuple{N,String},Array{T,N}} where {N,T}`: Dimension names and data array.
14

15
  - `coords::Dict{String,Any}`: Dict mapping dimension names to index names. Possible
    arguments have the same form as `data_vars`.
17
  - `attrs::Dict{String,Any}`: Global attributes to save on this dataset.
18

19
In most cases, use [`convert_to_dataset`](@ref) or [`convert_to_constant_dataset`](@ref)
to create a `Dataset` instead of directly using a constructor.
21
"""
22
struct Dataset
    # The wrapped Python object.
    o::PyObject

    # Inner constructor: accept only Python objects that are instances of
    # `xarray.Dataset`; anything else is rejected with an `ArgumentError`.
    function Dataset(o::PyObject)
        pyisinstance(o, xarray.Dataset) ||
            throw(ArgumentError("$o is not an `xarray.Dataset`."))
        return new(o)
    end
end
32

33 20
# Keyword constructor: forward every keyword to the Python `xarray.Dataset` constructor
# and wrap the resulting object.
function Dataset(; kwargs...)
    return Dataset(xarray.Dataset(; kwargs...))
end

# Constructing a `Dataset` from a `Dataset` returns the argument itself.
@inline function Dataset(data::Dataset)
    return data
end
35

36 20
# Unwrap a `Dataset` to its underlying Python object. The field is read with `getfield`
# rather than `data.o`, so the `getproperty` overload for `Dataset` is not involved.
@inline function PyObject(data::Dataset)
    return getfield(data, :o)
end
37

38 20
# Conversions to `Dataset`:
#   - a `PyObject` is wrapped through the `Dataset` constructor,
#   - a `Dataset` is returned as is,
#   - anything else is passed to `convert_to_dataset`.
function Base.convert(::Type{Dataset}, obj::PyObject)
    return Dataset(obj)
end
function Base.convert(::Type{Dataset}, obj::Dataset)
    return obj
end
function Base.convert(::Type{Dataset}, obj)
    return convert_to_dataset(obj)
end
41

42 20
# Hash a `Dataset` through the Python object it wraps.
#
# The one-argument method is kept unchanged. The two-argument method is the form Julia
# expects new types to implement; it is what gets called when a `Dataset` is hashed as
# part of a larger value (e.g. inside a tuple). It mixes the one-argument hash with `h`.
# NOTE(review): `hash(hash(data), h)` is used instead of `hash(PyObject(data), h)` on the
# assumption that PyCall only specializes the one-argument `hash` for `PyObject` — confirm.
Base.hash(data::Dataset) = hash(PyObject(data))
Base.hash(data::Dataset, h::UInt) = hash(hash(data), h)
43

44 20
# The property names of a `Dataset` are those of the wrapped Python object.
function Base.propertynames(data::Dataset)
    return propertynames(PyObject(data))
end
45

46
# Forward property access to the wrapped Python object. The name `:o` is special-cased
# and returns the wrapped object itself.
function Base.getproperty(data::Dataset, name::Symbol)
    pyobj = PyObject(data)
    if name === :o
        return pyobj
    end
    return getproperty(pyobj, name)
end
51

52 20
# Index into the dataset by evaluating the Python expression `data[k]`, with both `data`
# and `k` interpolated into the `py"..."` string.
function Base.getindex(data::Dataset, k)
    return py"$(data)[$k]"
end
53

54
# Plain display: print Python's `str` of the dataset, with the "<xarray.Dataset>" tag
# replaced by "Dataset (xarray.Dataset)".
function Base.show(io::IO, data::Dataset)
    text = replace(
        pycall(pybuiltin("str"), String, data),
        "<xarray.Dataset>" => "Dataset (xarray.Dataset)",
    )
    print(io, text)
    return nothing
end
60
# HTML display: print the wrapped object's `_repr_html_()` output with the
# `xarray.Dataset` label (optionally wrapped in literal or HTML-escaped angle brackets)
# replaced by "Dataset (xarray.Dataset)".
function Base.show(io::IO, ::MIME"text/html", data::Dataset)
    obj = PyObject(data)
    # Fall back to the plain `show` when the wrapped object exposes no `_repr_html_`.
    (:_repr_html_ in propertynames(obj)) || return show(io, data)
    out = obj._repr_html_()
    # The dot is escaped so that only the literal text `xarray.Dataset` is matched; an
    # unescaped `.` would match any character in that position.
    out = replace(out, r"(<|&lt;)?xarray\.Dataset(>|&gt;)?" => "Dataset (xarray.Dataset)")
    print(io, out)
    return nothing
end
68

69 20
# Return the `_attrs` property of the wrapped Python object.
function attributes(data::Dataset)
    return getproperty(PyObject(data), :_attrs)
end
70

71
# Set attribute `key` to `value` on the dataset. The current attributes are merged with
# the new pair, the merged collection is written back to the wrapped object's `_attrs`
# property, and the merged collection is returned.
function setattribute!(data::Dataset, key, value)
    updated = merge(attributes(data), Dict(key => value))
    setproperty!(PyObject(data), :_attrs, updated)
    return updated
end
76

77
@doc doc"""
    convert_to_dataset(obj; group = :posterior, kwargs...) -> Dataset

Convert a supported object to a `Dataset`.

In most cases, this function calls [`convert_to_inference_data`](@ref) and returns the
corresponding `group`. When `obj` is already a `Dataset`, it is returned unchanged.
"""
convert_to_dataset

function convert_to_dataset(obj; group=:posterior, kwargs...)
    # Accept the group name in any form `Symbol` can be built from.
    groupname = Symbol(group)
    idata = convert_to_inference_data(obj; group=groupname, kwargs...)
    return getproperty(idata, groupname)
end
# A `Dataset` needs no conversion; keyword arguments are ignored.
function convert_to_dataset(data::Dataset; kwargs...)
    return data
end
94

95
@doc doc"""
    convert_to_constant_dataset(obj::Dict; kwargs...) -> Dataset
    convert_to_constant_dataset(obj::NamedTuple; kwargs...) -> Dataset

Convert `obj` into a `Dataset`.

Unlike [`convert_to_dataset`](@ref), this is intended for containing constant parameters
such as observed data and constant data, and the first two dimensions are not required to be
the number of chains and draws.

# Keywords

- `coords::Dict{String,Vector}`: Map from named dimension to index names
- `dims::Dict{String,Vector{String}}`: Map from variable name to names of its dimensions
- `library::Any`: A library associated with the data to add to `attrs`.
- `attrs::Dict{String,Any}`: Global attributes to save on this dataset.
"""
convert_to_constant_dataset

function convert_to_constant_dataset(
    obj; coords=nothing, dims=nothing, library=nothing, attrs=nothing
)
    base = arviz.data.base

    # Pass every dict-like input through `_asstringkeydict` before use.
    vardict = _asstringkeydict(obj)
    coorddict = _asstringkeydict(coords)
    dimdict = _asstringkeydict(dims)
    userattrs = _asstringkeydict(attrs)

    # Build one `xarray.DataArray` per variable, letting `generate_dims_coords` produce
    # the dimension names and coordinates for it.
    data = Dict{String,PyObject}()
    for (name, raw) in vardict
        arr = _asarray(raw)
        (arr_dims, arr_coords) = base.generate_dims_coords(
            size(arr), name; dims=get(dimdict, name, nothing), coords=coorddict
        )
        data[name] = xarray.DataArray(arr; dims=arr_dims, coords=arr_coords)
    end

    # Start from the default attributes, add the library name when given, and merge the
    # caller's attributes last so they take precedence.
    defaults = base.make_attrs()
    if library !== nothing
        defaults = merge(defaults, Dict("inference_library" => string(library)))
    end
    merged = merge(defaults, userattrs)
    return Dataset(; data_vars=data, coords=coorddict, attrs=merged)
end
141

142
@doc doc"""
    dict_to_dataset(data::Dict{String,Array}; kwargs...) -> Dataset

Convert a dictionary with data and keys as variable names to a [`Dataset`](@ref).

# Keywords

- `attrs::Dict{String,Any}`: Json serializable metadata to attach to the dataset, in
    addition to defaults.
- `library::String`: Name of library used for performing inference. Will be attached to the
    `attrs` metadata.
- `coords::Dict{String,Array}`: Coordinates for the dataset
- `dims::Dict{String,Vector{String}}`: Dimensions of each variable. The keys are variable
    names, values are vectors of coordinates.

# Examples

```@example
using ArviZ
ArviZ.dict_to_dataset(Dict("x" => randn(4, 100), "y" => randn(4, 100)))
```
"""
dict_to_dataset

function dict_to_dataset(data; library=nothing, attrs=nothing, kwargs...)
    # When a library is given, record its name under "inference_library", either as the
    # only attribute or merged into the attributes supplied by the caller.
    if library !== nothing
        libattrs = Dict("inference_library" => string(library))
        if attrs === nothing
            attrs = libattrs
        else
            attrs = merge(attrs, libattrs)
        end
    end
    return arviz.dict_to_dataset(data; attrs=attrs, kwargs...)
end
173

174
@doc doc"""
    dataset_to_dict(ds::Dataset) -> Tuple{Dict{String,Array},NamedTuple}

Convert a `Dataset` to a dictionary of `Array`s. The function also returns keyword arguments
to [`dict_to_dataset`](@ref).

The returned keyword arguments are `attrs`, `coords` and `dims`. The `"chain"` and `"draw"`
coordinates are removed from `coords`, and the first two dimension names of each variable
are left out of `dims`.
"""
dataset_to_dict

function dataset_to_dict(ds::Dataset)
    raw = ds.to_dict()
    vars = raw["data_vars"]
    attrs = raw["attrs"]

    # Drop the "chain" and "draw" coordinates and keep only the data of the others.
    rawcoords = raw["coords"]
    for name in ("chain", "draw")
        delete!(rawcoords, name)
    end
    coords = Dict(k => v["data"] for (k, v) in rawcoords)

    data = Dict{String,Array}()
    dims = Dict{String,Vector{String}}()
    for (varname, entry) in vars
        data[varname] = entry["data"]
        # Skip the first two dimension names (presumably chain and draw — confirm);
        # a variable gets a `dims` entry only when further dimensions remain.
        extra = entry["dims"][3:end]
        isempty(extra) || (dims[varname] = [extra...])
    end

    return data, (attrs=attrs, coords=coords, dims=dims)
end

Read our documentation on viewing source code.

Loading