1
|
1
|
module CombinatorialBandits
|
2
|
|
using LinearAlgebra
|
3
|
|
|
4
|
|
using DataStructures
|
5
|
|
using IterTools
|
6
|
|
using Distributions
|
7
|
|
|
8
|
|
import Munkres # Avoid clashing with Hungarian.
|
9
|
|
import Hungarian
|
10
|
|
using LightGraphs
|
11
|
|
using JuMP
|
12
|
|
|
13
|
|
import Base: push!, copy, hash, isequal
|
14
|
|
import JuMP: value
|
15
|
|
|
16
|
|
export Policy, CombinatorialInstance, State,
|
17
|
|
initial_state, initial_trace, simulate, choose_action, pull, update!, solve_linear, solve_budgeted_linear, solve_all_budgeted_linear, get_lp_formulation, has_lp_formulation, is_feasible, is_partially_acceptable,
|
18
|
|
ThompsonSampling, LLR, CUCB, ESCB2, OLSUCB,
|
19
|
|
ThompsonSamplingDetails, LLRDetails, CUCBDetails, ESCB2Details, OLSUCBDetails,
|
20
|
|
ESCB2OptimisationAlgorithm, ESCB2Exact, ESCB2Greedy, ESCB2Budgeted, OLSUCBOptimisationAlgorithm, OLSUCBGreedy,
|
21
|
|
PerfectBipartiteMatching, UncorrelatedPerfectBipartiteMatching, CorrelatedPerfectBipartiteMatching, PerfectBipartiteMatchingSolver, PerfectBipartiteMatchingLPSolver, PerfectBipartiteMatchingMunkresSolver, PerfectBipartiteMatchingHungarianSolver, PerfectBipartiteMatchingAlgosSolver,
|
22
|
|
ElementaryPath, ElementaryPathSolver, ElementaryPathLightGraphsDijkstraSolver, ElementaryPathLPSolver, ElementaryPathAlgosSolver,
|
23
|
|
SpanningTree, SpanningTreeSolver, SpanningTreeLightGraphsPrimSolver, SpanningTreeAlgosSolver, SpanningTreeLPSolver,
|
24
|
|
MSet, MSetSolver, MSetAlgosSolver, MSetLPSolver,
|
25
|
|
# Algos.
|
26
|
|
MSetInstance, MSetSolution, dimension, m, value, values, msets_greedy, msets_dp, msets_lp,
|
27
|
|
BudgetedMSetInstance, BudgetedMSetSolution, weight, weights, budget, max_weight, items, items_all_budgets, budgeted_msets_dp, budgeted_msets_lp, budgeted_msets_lp_select, budgeted_msets_lp_all,
|
28
|
|
ElementaryPathInstance, ElementaryPathSolution, graph, costs, src, dst, cost, lp_dp,
|
29
|
|
BudgetedElementaryPathInstance, BudgetedElementaryPathSolution, rewards, reward, budgeted_lp_dp,
|
30
|
|
SpanningTreeInstance, SpanningTreeSolution, st_prim,
|
31
|
|
BudgetedSpanningTreeInstance, BudgetedSpanningTreeSolution, BudgetedSpanningTreeLagrangianSolution, SimpleBudgetedSpanningTreeSolution, _budgeted_spanning_tree_compute_value, _budgeted_spanning_tree_compute_weight, st_prim_budgeted_lagrangian, st_prim_budgeted_lagrangian_search, _solution_symmetric_difference, _solution_symmetric_difference_size, st_prim_budgeted_lagrangian_refinement, st_prim_budgeted_lagrangian_approx_half,
|
32
|
|
BipartiteMatchingInstance, BipartiteMatchingSolution, matching_hungarian, BudgetedBipartiteMatchingInstance, BudgetedBipartiteMatchingSolution, BudgetedBipartiteMatchingLagrangianSolution, SimpleBudgetedBipartiteMatchingSolution, matching_hungarian_budgeted_lagrangian, matching_hungarian_budgeted_lagrangian_search, matching_hungarian_budgeted_lagrangian_refinement, matching_hungarian_budgeted_lagrangian_approx_half
|
33
|
|
|
34
|
|
# General algorithm.

"""
    Policy

Abstract supertype of every bandit policy (the `policy` argument of `choose_action` and
`simulate`).
"""
abstract type Policy end

"""
    PolicyDetails

Abstract supertype of the per-round information a policy reports when a trace is recorded
(stored in `Trace.policy_details`).
"""
abstract type PolicyDetails end

"""
    CombinatorialInstance{T}

Abstract supertype of combinatorial bandit instances whose arms are identified by values of
type `T` (for instance `Int`, or `Tuple{Int, Int}`).
"""
abstract type CombinatorialInstance{T} end
|
38
|
|
|
39
|
|
"""
    State{T}

Mutable state of a bandit run; it evolves at each round (see `update!`).

# Fields
- `round`: number of rounds played so far.
- `regret`: cumulative regret incurred so far.
- `reward`: cumulative reward observed over all pulled arms.
- `arm_counts`: number of times each arm has been pulled.
- `arm_reward`: sum of the rewards observed for each arm.
- `arm_average_reward`: empirical average reward of each arm (`arm_reward / arm_counts`).
"""
mutable struct State{T}
    round::Int
    regret::Float64
    reward::Float64
    arm_counts::Dict{T, Int}
    arm_reward::Dict{T, Float64}
    arm_average_reward::Dict{T, Float64}
end

"""
    copy(s::State{T}) where T

Return an independent snapshot of `s`. The per-arm dictionaries are copied as well. A shallow
copy would share them with `s`, and since `update!` mutates those dictionaries in place, every
snapshot kept in a trace would then silently reflect the latest state.
"""
copy(s::State{T}) where T = State{T}(s.round, s.regret, s.reward,
                                     copy(s.arm_counts), copy(s.arm_reward), copy(s.arm_average_reward))
|
50
|
|
|
51
|
|
"""
    Trace{T}

Execution trace of a bandit run. Each vector receives one entry per recorded round, appended
together by `push!`.

# Fields
- `states`: snapshot of the bandit state at each recorded round.
- `arms`: arms chosen at each round.
- `reward`: rewards observed for the chosen arms at each round.
- `policy_details`: details reported by the policy at each round.
- `time_choose_action`: time spent choosing the action at each round (intended unit:
  milliseconds).
"""
struct Trace{T}
    states::Vector{State{T}}
    arms::Vector{Vector{T}}
    reward::Vector{Vector{Float64}}
    policy_details::Vector{PolicyDetails}
    time_choose_action::Vector{Int}
end
|
59
|
|
|
60
|
|
"""
    push!(trace::Trace{T}, state::State{T}, arms::Vector{T}, reward::Vector{Float64}, policy_details::PolicyDetails, time_choose_action::Int) where T

Record one round in `trace`. The inputs `state`, `arms`, `reward`, `policy_details` and
`time_choose_action` (in milliseconds) are each appended to the matching vector of `trace`.

`state`, `arms` and `reward` are stored as copies, while `policy_details` is stored as given.
Callers usually keep mutating the state, arms and rewards from one round to the next, but
build a brand-new details object every round.
"""
function push!(trace::Trace{T}, state::State{T}, arms::Vector{T}, reward::Vector{Float64}, policy_details::PolicyDetails, time_choose_action::Int) where T
    # Snapshot the inputs that the caller keeps mutating across rounds.
    state_snapshot = copy(state)
    arms_snapshot = copy(arms)
    reward_snapshot = copy(reward)

    push!(trace.states, state_snapshot)
    push!(trace.arms, arms_snapshot)
    push!(trace.reward, reward_snapshot)
    push!(trace.policy_details, policy_details)
    return push!(trace.time_choose_action, time_choose_action)
end
|
75
|
|
|
76
|
|
# Interface for combinatorial instances.
# These generic fallbacks do nothing and return `nothing`. Concrete instance types (or more
# specific methods, such as the `CombinatorialInstance{Int}` ones in this module) provide the
# real behaviour.

# TODO: Can't `initial_state` and `initial_trace` be provided by default based on template types?
initial_state(instance::CombinatorialInstance{T}) where T = nothing
initial_trace(instance::CombinatorialInstance{T}) where T = nothing

is_feasible(instance::CombinatorialInstance{T}, arms::Vector{T}) where T = nothing
is_partially_acceptable(instance::CombinatorialInstance{T}, arms::Vector{T}) where T = nothing
|
81
|
|
|
82
|
|
"""
    all_arm_indices(reward::AbstractMatrix{<:Distribution})

Return the indices of all arms of an uncorrelated instance whose reward distributions are
stored in a matrix. The result is a matrix with the same shape as `reward`, holding one index
tuple (e.g. `(row, column)`) per entry, in column-major order.
"""
function all_arm_indices(reward::AbstractMatrix{<:Distribution})
    # `<:` so that concretely-typed matrices (e.g. `Matrix{Normal{Float64}}`) are accepted too:
    # parametric types are invariant, so `Matrix{Distribution}` alone would not match them.
    return [Tuple(index) for index in CartesianIndices(reward)]
end
|
86
|
|
|
87
|
|
"""
    all_arm_indices(reward::AbstractVector{<:Distribution})

Return the indices of all arms of an uncorrelated instance whose reward distributions are
stored in a vector. For a `Vector`, this is `Base.OneTo(length(reward))`.
"""
function all_arm_indices(reward::AbstractVector{<:Distribution})
    # `<:` so that concretely-typed vectors (e.g. `Vector{Normal{Float64}}`) are accepted too.
    # `eachindex(reward)` is what the previous full-range `view` computed, without the view.
    return eachindex(reward)
end
|
90
|
|
|
91
|
|
"""
    all_arm_indices(reward::AbstractDict{T, <:Distribution}) where T

Return the arms of an uncorrelated instance whose reward distributions are stored in a
dictionary, as a `Vector{T}` of its keys (in the dictionary's iteration order).
"""
function all_arm_indices(reward::AbstractDict{T, <:Distribution}) where T
    # `<:` so that dictionaries with a concrete value type (e.g. `Dict{Int, Normal{Float64}}`)
    # are accepted too: `Dict{T, Distribution}` alone would not match them.
    return collect(keys(reward))
end
|
94
|
|
|
95
|
|
"""
    all_arm_indices(instance::CombinatorialInstance{T}) where T

Return the collection of arm indices of `instance`.

For correlated arms (`instance.reward` is a `MultivariateDistribution`), only a flat vector of
rewards is known, whatever the kind of arms. Nothing generic can be derived from it, so the
instance provides the indices itself in `instance.all_arm_indices`. For uncorrelated arms, the
indices are derived from the container of reward distributions (vector, matrix, or
dictionary).
"""
function all_arm_indices(instance::CombinatorialInstance{T}) where T
    reward = instance.reward
    # Uncorrelated arms: delegate to the method for the container of distributions.
    reward isa MultivariateDistribution || return all_arm_indices(reward)
    # Correlated arms.
    return instance.all_arm_indices
end
|
106
|
|
|
107
|
|
"""
    pull(instance::CombinatorialInstance{T}, arms::Vector{T}) where T

Play the arms `arms` on `instance` for one round, with semi-bandit feedback (one observed
reward per played arm).

Return a tuple `(reward, incurred_regret)`:
- `reward::Vector{Float64}`: the drawn reward of each arm of `arms`, in the same order.
- `incurred_regret`: `instance.optimal_average_reward` minus the sum of the expected rewards of
  `arms`.

For correlated arms (`instance.reward` is a `MultivariateDistribution`), one joint sample is
drawn and mapped to arms following the order of `all_arm_indices(instance)`. For uncorrelated
arms, one sample is drawn for every arm of the instance, not only for the played ones.
"""
function pull(instance::CombinatorialInstance{T}, arms::Vector{T}) where T
    arm_indices = all_arm_indices(instance)

    if isa(instance.reward, MultivariateDistribution) # Correlated arms.
        true_rewards_vector = mean(instance.reward)
        true_rewards = Dict{T, Float64}(arm => true_rewards_vector[i] for (i, arm) in enumerate(arm_indices))

        # Use the random joint draw, not the means: the feedback must be stochastic.
        drawn_rewards_vector = rand(instance.reward)
        drawn_rewards = Dict{T, Float64}(arm => drawn_rewards_vector[i] for (i, arm) in enumerate(arm_indices))
    else # Uncorrelated arms.
        # If T is a tuple, the reward distributions are stored in a matrix, hence the splatting.
        true_rewards = Dict{T, Float64}(i => mean(instance.reward[i...]) for i in arm_indices)
        drawn_rewards = Dict{T, Float64}(i => rand(instance.reward[i...]) for i in arm_indices)
    end

    # Select the information that will be provided back to the bandit policy (semi-bandit setting).
    true_reward = sum(true_rewards[arm] for arm in arms)
    reward = Float64[drawn_rewards[arm] for arm in arms]

    # Regret incurred by the provided solution, measured on expected rewards.
    incurred_regret = instance.optimal_average_reward - true_reward

    return reward, incurred_regret
end
|
132
|
|
|
133
|
1
|
# Optimisation problems on an instance are forwarded to the solver stored in `instance.solver`.

function solve_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}) where T
    return solve_linear(instance.solver, rewards)
end

function solve_budgeted_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64},
                               weights::Dict{T, Int}, budget::Int) where T
    return solve_budgeted_linear(instance.solver, rewards, weights, budget)
end

function solve_all_budgeted_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64},
                                   weights::Dict{T, Int}, max_budget::Int) where T
    return solve_all_budgeted_linear(instance.solver, rewards, weights, max_budget)
end

function has_lp_formulation(instance::CombinatorialInstance{T}) where T
    return has_lp_formulation(instance.solver)
end

# Raises an error when the instance's solver does not work through an LP formulation.
function get_lp_formulation(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}) where T
    if !has_lp_formulation(instance)
        error("The chosen solver uses no LP formulation.")
    end
    return get_lp_formulation(instance.solver, rewards)
end
|
142
|
|
|
143
|
|
# Default for the most common case: arms numbered `1:instance.n_arms`. For tuple-indexed arms,
# the set of indices varies too widely for a generic default.
"""
    initial_state(instance::CombinatorialInstance{Int})

Return the state before the first round: round 0, no regret, no reward, and for every arm
`1:instance.n_arms` a zero pull count, zero total reward and zero average reward.
"""
function initial_state(instance::CombinatorialInstance{Int})
    all_arms = 1:instance.n_arms
    counts = Dict{Int, Int}(arm => 0 for arm in all_arms)
    totals = Dict{Int, Float64}(arm => 0.0 for arm in all_arms)
    averages = Dict{Int, Float64}(arm => 0.0 for arm in all_arms)
    return State{Int}(0, 0.0, 0.0, counts, totals, averages)
end
|
150
|
|
|
151
|
|
# Empty execution trace for arms of type `T` (helper for the `initial_trace` methods).
_empty_trace(::Type{T}) where T = Trace{T}(State{T}[], Vector{T}[], Vector{Float64}[], PolicyDetails[], Int[])

"""
    initial_trace(instance::CombinatorialInstance{Int})
    initial_trace(instance::CombinatorialInstance{Tuple{Int, Int}})

Return an empty execution trace whose arm type matches the one of `instance`.
"""
initial_trace(instance::CombinatorialInstance{Int}) = _empty_trace(Int)
initial_trace(instance::CombinatorialInstance{Tuple{Int, Int}}) = _empty_trace(Tuple{Int, Int})
|
158
|
|
|
159
|
|
# Interface for policies.
"""
    choose_action(instance::CombinatorialInstance{T}, policy::Policy, state::State{T}) where T

Generic fallback of the policy interface; it does nothing and returns `nothing`. `simulate`
calls `choose_action` with a `with_trace` keyword, which this method does not accept, so each
policy has to define its own method.
"""
choose_action(instance::CombinatorialInstance{T}, policy::Policy, state::State{T}) where T = nothing
|
161
|
|
|
162
|
|
"""
    update!(state::State{T}, instance::CombinatorialInstance{T}, arms::Vector{T}, reward::Vector{Float64}, incurred_regret::Float64) where T

Fold the outcome of one round into `state`, in place, before the next round starts.
`reward[k]` is the reward observed for `arms[k]`. This is semi-bandit feedback: one reward per
arm, rather than a single reward for all arms. `instance` is not used by this method.
"""
function update!(state::State{T}, instance::CombinatorialInstance{T}, arms::Vector{T}, reward::Vector{Float64}, incurred_regret::Float64) where T
    state.round += 1

    for (k, arm) in enumerate(arms)
        state.arm_counts[arm] += 1
        state.arm_reward[arm] += reward[k]
        # Refresh the empirical mean from the updated total and count.
        state.arm_average_reward[arm] = state.arm_reward[arm] / state.arm_counts[arm]
        state.reward += reward[k]
    end

    state.regret += incurred_regret
end
|
176
|
|
|
177
|
|
"""
    simulate(instance::CombinatorialInstance{T}, policy::Policy, steps::Int; with_trace::Bool=false) where T

Run `policy` on `instance` for `steps` rounds, starting from `initial_state(instance)`.

Each round, the policy chooses arms, the arms are pulled, and the state is updated. The round
number is printed every 100 rounds. Throws an error if the policy returns no arms.

Return the final `state`. When `with_trace` is true, return `(state, trace)` instead, where the
trace records every round along with the time spent in `choose_action`, in milliseconds.
"""
function simulate(instance::CombinatorialInstance{T}, policy::Policy, steps::Int; with_trace::Bool=false) where T
    state = initial_state(instance)
    if with_trace
        trace = initial_trace(instance)
    end

    for i in 1:steps
        t0 = time_ns()
        if with_trace
            arms, run_details = choose_action(instance, policy, state, with_trace=true)
        else
            arms = choose_action(instance, policy, state, with_trace=false)
        end
        t1 = time_ns()

        if isempty(arms)
            error("No arms have been chosen at round $(i)!")
        end

        reward, incurred_regret = pull(instance, arms)
        update!(state, instance, arms, reward, incurred_regret)

        if with_trace
            # time_ns() counts nanoseconds and the trace stores milliseconds: divide by 10^6.
            # Dividing by 10^9 would give seconds, which rounds most rounds down to 0.
            push!(trace, state, arms, reward, run_details, round(Int, (t1 - t0) / 1_000_000))
        end

        if i % 100 == 0
            println(i)
        end
    end

    if !with_trace
        return state
    else
        return state, trace
    end
end
|
215
|
|
|
216
|
|
## Combinatorial algorithms. TODO: put this in another package.
|
217
|
|
include("algos/helpers.jl")
|
218
|
|
include("algos/ep.jl")
|
219
|
|
include("algos/ep_budgeted.jl")
|
220
|
|
include("algos/matching.jl")
|
221
|
|
include("algos/matching_budgeted.jl")
|
222
|
|
include("algos/msets.jl")
|
223
|
|
include("algos/msets_budgeted.jl")
|
224
|
|
include("algos/st.jl")
|
225
|
|
include("algos/st_budgeted.jl")
|
226
|
|
|
227
|
|
## Bandit policies.
|
228
|
|
include("policies/thompson.jl")
|
229
|
|
include("policies/llr.jl")
|
230
|
|
include("policies/cucb.jl")
|
231
|
|
include("policies/escb2.jl")
|
232
|
|
include("policies/olsucb.jl")
|
233
|
|
|
234
|
|
include("policies/escb2_exact.jl")
|
235
|
|
include("policies/escb2_greedy.jl")
|
236
|
|
include("policies/escb2_budgeted.jl")
|
237
|
|
|
238
|
|
include("policies/olsucb_greedy.jl")
|
239
|
|
|
240
|
|
## Potential problems to solve.
|
241
|
|
include("instances/perfectbipartitematching.jl")
|
242
|
|
include("instances/perfectbipartitematching_algos.jl")
|
243
|
|
include("instances/perfectbipartitematching_lp.jl")
|
244
|
|
include("instances/perfectbipartitematching_munkres.jl")
|
245
|
|
include("instances/perfectbipartitematching_hungarian.jl")
|
246
|
|
# Not using LightGraphsMatching, due to BlossomV dependency (hard to get to work…)
|
247
|
|
|
248
|
|
include("instances/elementarypath.jl")
|
249
|
|
include("instances/elementarypath_algos.jl")
|
250
|
|
include("instances/elementarypath_lightgraphsdijkstra.jl")
|
251
|
|
include("instances/elementarypath_lp.jl")
|
252
|
|
|
253
|
|
include("instances/spanningtree.jl")
|
254
|
|
include("instances/spanningtree_algos.jl")
|
255
|
|
include("instances/spanningtree_lightgraphsprim.jl")
|
256
|
|
include("instances/spanningtree_lp.jl")
|
257
|
|
|
258
|
|
include("instances/mset.jl")
|
259
|
|
include("instances/mset_algos.jl")
|
260
|
|
include("instances/mset_lp.jl")
|
261
|
|
end
|