1 1
module CombinatorialBandits
2
  using LinearAlgebra
3

4
  using DataStructures
5
  using IterTools
6
  using Distributions
7

8
  import Munkres # Avoid clashing with Hungarian.
9
  import Hungarian
10
  using LightGraphs
11
  using JuMP
12

13
  import Base: push!, copy, hash, isequal
14
  import JuMP: value
15

16
  export Policy, CombinatorialInstance, State,
17
         initial_state, initial_trace, simulate, choose_action, pull, update!, solve_linear, solve_budgeted_linear, solve_all_budgeted_linear, get_lp_formulation, has_lp_formulation, is_feasible, is_partially_acceptable,
18
         ThompsonSampling, LLR, CUCB, ESCB2, OLSUCB,
19
         ThompsonSamplingDetails, LLRDetails, CUCBDetails, ESCB2Details, OLSUCBDetails,
20
         ESCB2OptimisationAlgorithm, ESCB2Exact, ESCB2Greedy, ESCB2Budgeted, OLSUCBOptimisationAlgorithm, OLSUCBGreedy,
21
         PerfectBipartiteMatching, UncorrelatedPerfectBipartiteMatching, CorrelatedPerfectBipartiteMatching, PerfectBipartiteMatchingSolver, PerfectBipartiteMatchingLPSolver, PerfectBipartiteMatchingMunkresSolver, PerfectBipartiteMatchingHungarianSolver, PerfectBipartiteMatchingAlgosSolver,
22
         ElementaryPath, ElementaryPathSolver, ElementaryPathLightGraphsDijkstraSolver, ElementaryPathLPSolver, ElementaryPathAlgosSolver,
23
         SpanningTree, SpanningTreeSolver, SpanningTreeLightGraphsPrimSolver, SpanningTreeAlgosSolver, SpanningTreeLPSolver,
24
         MSet, MSetSolver, MSetAlgosSolver, MSetLPSolver,
25
         # Algos.
26
         MSetInstance, MSetSolution, dimension, m, value, values, msets_greedy, msets_dp, msets_lp,
27
         BudgetedMSetInstance, BudgetedMSetSolution, weight, weights, budget, max_weight, items, items_all_budgets, budgeted_msets_dp, budgeted_msets_lp, budgeted_msets_lp_select, budgeted_msets_lp_all,
28
         ElementaryPathInstance, ElementaryPathSolution, graph, costs, src, dst, cost, lp_dp,
29
         BudgetedElementaryPathInstance, BudgetedElementaryPathSolution, rewards, reward, budgeted_lp_dp,
30
         SpanningTreeInstance, SpanningTreeSolution, st_prim,
31
         BudgetedSpanningTreeInstance, BudgetedSpanningTreeSolution, BudgetedSpanningTreeLagrangianSolution, SimpleBudgetedSpanningTreeSolution, _budgeted_spanning_tree_compute_value, _budgeted_spanning_tree_compute_weight, st_prim_budgeted_lagrangian, st_prim_budgeted_lagrangian_search, _solution_symmetric_difference, _solution_symmetric_difference_size, st_prim_budgeted_lagrangian_refinement, st_prim_budgeted_lagrangian_approx_half,
32
         BipartiteMatchingInstance, BipartiteMatchingSolution, matching_hungarian, BudgetedBipartiteMatchingInstance, BudgetedBipartiteMatchingSolution, BudgetedBipartiteMatchingLagrangianSolution, SimpleBudgetedBipartiteMatchingSolution, matching_hungarian_budgeted_lagrangian, matching_hungarian_budgeted_lagrangian_search, matching_hungarian_budgeted_lagrangian_refinement, matching_hungarian_budgeted_lagrangian_approx_half
33

34
  # General algorithm.
35
  # Base type for bandit policies; `choose_action` and `simulate` both take a `Policy` argument.
  abstract type Policy end
36
  # Base type for the per-round details a policy reports; `Trace` stores one such object per recorded round.
  abstract type PolicyDetails end
37
  # Base type for problem instances; `T` is the type identifying one arm
  # (this file provides defaults for `Int` and `Tuple{Int, Int}`).
  abstract type CombinatorialInstance{T} end
38

39
  # Define the state of a bandit (evolves at each round).
40
  """
      State{T}

  State of a bandit run, updated in place at each round by `update!`.

  Fields:
  * `round`: number of rounds played so far.
  * `regret`: regret accumulated over all rounds.
  * `reward`: reward accumulated over all rounds.
  * `arm_counts`: for each arm, how many times it has been played.
  * `arm_reward`: for each arm, the total reward it has yielded.
  * `arm_average_reward`: for each arm, `arm_reward / arm_counts`.
  """
  mutable struct State{T}
    round::Int
    regret::Float64
    reward::Float64
    arm_counts::Dict{T, Int}
    arm_reward::Dict{T, Float64}
    arm_average_reward::Dict{T, Float64}
  end

  """
      copy(s::State{T}) where T

  Returns an independent snapshot of `s`.

  The three per-arm dictionaries are copied too: `update!` mutates them in place, so sharing them
  would make a previously taken snapshot silently follow every later round.
  """
  copy(s::State{T}) where T = State{T}(
    s.round,
    s.regret,
    s.reward,
    copy(s.arm_counts),
    copy(s.arm_reward),
    copy(s.arm_average_reward)
  )
50

51
  # Define the trace of the execution throughout the rounds.
52
  """
      Trace{T}

  Round-by-round record of a bandit run. All vectors grow together, one entry per recorded round:
  the state snapshot, the arms played, the rewards observed for these arms, the details reported by
  the policy, and the time spent choosing the action.
  """
  struct Trace{T}
    states::Vector{State{T}}
    arms::Vector{Vector{T}}
    reward::Vector{Vector{Float64}}
    policy_details::Vector{PolicyDetails}
    time_choose_action::Vector{Int}
  end

  """
      push!(trace::Trace{T}, state::State{T}, arms::Vector{T}, reward::Vector{Float64}, policy_details::PolicyDetails, time_choose_action::Int) where T

  Records one round in `trace`: `state`, `arms`, `reward`, `policy_details`, and `time_choose_action`
  (expressed in milliseconds) are each appended to the corresponding vector of `trace`.

  `state`, `arms`, and `reward` are copied before being stored, because the caller usually keeps
  updating them from one round to the next. `policy_details` is stored as is, since it is usually
  built from scratch at each round.
  """
  function push!(trace::Trace{T}, state::State{T}, arms::Vector{T}, reward::Vector{Float64}, policy_details::PolicyDetails, time_choose_action::Int) where T
    # Snapshots of what the caller may still mutate.
    state_snapshot = copy(state)
    push!(trace.states, state_snapshot)

    arms_snapshot = copy(arms)
    push!(trace.arms, arms_snapshot)

    reward_snapshot = copy(reward)
    push!(trace.reward, reward_snapshot)

    # Stored without copying.
    push!(trace.policy_details, policy_details)
    return push!(trace.time_choose_action, time_choose_action)
  end
75

76
  # Interface for combinatorial instances.
77 0
  # Fallback methods of the instance interface: each one does nothing and returns `nothing`.
  # TODO: Can't `initial_state` be provided by default based on template types?
  initial_state(instance::CombinatorialInstance{T}) where T = nothing
  # TODO: Can't `initial_trace` be provided by default based on template types?
  initial_trace(instance::CombinatorialInstance{T}) where T = nothing
  is_feasible(instance::CombinatorialInstance{T}, arms::Vector{T}) where T = nothing
  is_partially_acceptable(instance::CombinatorialInstance{T}, arms::Vector{T}) where T = nothing
81

82
  # Arms of a matrix of reward distributions: one tuple of indices per entry of the matrix.
  function all_arm_indices(reward::Matrix{Distribution})
    n_rows, n_columns = size(reward)
    # Indices are taken from a view over explicit ranges, so that each of them
    # can be turned into a tuple with one component per dimension.
    window = view(reward, 1:n_rows, 1:n_columns)
    return [Tuple(index) for index in eachindex(window)]
  end
86

87
  # Arms of a vector of reward distributions: the indices of the vector.
  function all_arm_indices(reward::Vector{Distribution})
    window = view(reward, 1:length(reward))
    return eachindex(window)
  end
90

91
  # Arms of a dictionary of reward distributions: its keys, gathered in a vector.
  function all_arm_indices(reward::Dict{T, Distribution}) where T
    return T[arm for arm in keys(reward)]
  end
94

95
  # Arms of an instance, whichever way its rewards are stored.
  function all_arm_indices(instance::CombinatorialInstance{T}) where T
    # Uncorrelated arms: the container of distributions determines the arms.
    if !isa(instance.reward, MultivariateDistribution)
      return all_arm_indices(instance.reward)
    end

    # Correlated arms: nothing as generic as the uncorrelated case is available,
    # because only a vector is known, for any kind of arms (whereas uncorrelated
    # instances distinguish vectors from matrices of reward distributions).
    # The instance must therefore store its arms itself.
    return instance.all_arm_indices
  end
106

107
  """
      pull(instance::CombinatorialInstance{T}, arms::Vector{T}) where T

  Plays the arms `arms` for one round of `instance`.

  Returns a tuple `(reward, incurred_regret)`:
  * `reward` contains one drawn reward per played arm, in the order of `arms` (semi-bandit feedback);
  * `incurred_regret` is the difference between `instance.optimal_average_reward` and the sum of
    the expected rewards of the played arms.
  """
  function pull(instance::CombinatorialInstance{T}, arms::Vector{T}) where T
    # Draw the rewards for this round. If T is a tuple, the reward distributions
    # are stored in a matrix, hence the splatting.
    arm_indices = all_arm_indices(instance)
    if isa(instance.reward, MultivariateDistribution) # Correlated arms.
      # The i-th component of the joint distribution corresponds to the i-th arm of `arm_indices`.
      true_rewards_vector = mean(instance.reward)
      true_rewards = Dict{T, Float64}(arm => true_rewards_vector[i] for (i, arm) in enumerate(arm_indices))

      # The observed rewards must come from the random draw, not from the means.
      drawn_rewards_vector = rand(instance.reward)
      drawn_rewards = Dict{T, Float64}(arm => drawn_rewards_vector[i] for (i, arm) in enumerate(arm_indices))
    else # Uncorrelated arms.
      true_rewards = Dict{T, Float64}(i => mean(instance.reward[i...]) for i in arm_indices)
      drawn_rewards = Dict{T, Float64}(i => rand(instance.reward[i...]) for i in arm_indices)
    end

    # Select the information that will be provided back to the bandit policy.
    # Here is implemented the semi-bandit setting.
    true_reward = sum(true_rewards[arm] for arm in arms)
    reward = Float64[drawn_rewards[arm] for arm in arms]

    # Compute the incurred regret from the provided solution.
    incurred_regret = instance.optimal_average_reward - true_reward

    return reward, incurred_regret
  end
132

133 1
  # Delegates the linear optimisation over `rewards` to the solver attached to the instance.
  function solve_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}) where T
    return solve_linear(instance.solver, rewards)
  end
134 1
  # Delegates the budgeted linear optimisation to the solver attached to the instance.
  function solve_budgeted_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}, weights::Dict{T, Int}, budget::Int) where T
    return solve_budgeted_linear(instance.solver, rewards, weights, budget)
  end
136 0
  # Delegates the budgeted linear optimisation, with `max_budget` as bound, to the solver attached to the instance.
  function solve_all_budgeted_linear(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}, weights::Dict{T, Int}, max_budget::Int) where T
    return solve_all_budgeted_linear(instance.solver, rewards, weights, max_budget)
  end
138 0
  # Asks the solver attached to the instance whether it has an LP formulation.
  function has_lp_formulation(instance::CombinatorialInstance{T}) where T
    return has_lp_formulation(instance.solver)
  end
139 0
  # Returns the LP formulation built by the solver attached to the instance for `rewards`;
  # raises an error when `has_lp_formulation(instance)` is false.
  function get_lp_formulation(instance::CombinatorialInstance{T}, rewards::Dict{T, Float64}) where T
    if !has_lp_formulation(instance)
      error("The chosen solver uses no LP formulation.")
    end
    return get_lp_formulation(instance.solver, rewards)
  end
142

143
  # Implement the most common case. For tuples, the number of arms varies more widely.
144
  # State before the first round for integer-indexed arms 1, 2… `instance.n_arms`:
  # every counter and every reward starts at zero.
  function initial_state(instance::CombinatorialInstance{Int})
    arm_range = 1:instance.n_arms
    counts = Dict(arm => 0 for arm in arm_range)
    totals = Dict(arm => 0.0 for arm in arm_range)
    # Total and average rewards start equal, but must be distinct dictionaries.
    averages = copy(totals)
    return State{Int}(0, 0.0, 0.0, counts, totals, averages)
  end
150

151
  # Empty trace for integer-indexed arms.
  function initial_trace(instance::CombinatorialInstance{Int})
    states = State{Int}[]
    arms = Vector{Int}[]
    rewards = Vector{Float64}[]
    details = PolicyDetails[]
    timings = Int[]
    return Trace{Int}(states, arms, rewards, details, timings)
  end
154

155
  # Empty trace for arms indexed by a pair of integers.
  function initial_trace(instance::CombinatorialInstance{Tuple{Int, Int}})
    states = State{Tuple{Int, Int}}[]
    arms = Vector{Tuple{Int, Int}}[]
    rewards = Vector{Float64}[]
    details = PolicyDetails[]
    timings = Int[]
    return Trace{Tuple{Int, Int}}(states, arms, rewards, details, timings)
  end
158

159
  # Interface for policies.
160 0
  # Fallback method of the policy interface: does nothing and returns `nothing`.
  choose_action(instance::CombinatorialInstance{T}, policy::Policy, state::State{T}) where T = nothing
161

162
  # Update the state before the new round.
163
  # Accounts for one played round in `state`: `reward[i]` is the reward observed for `arms[i]`,
  # and `incurred_regret` the regret of the whole round.
  function update!(state::State{T}, instance::CombinatorialInstance{T}, arms::Vector{T}, reward::Vector{Float64}, incurred_regret::Float64) where T
    state.round += 1

    # One reward per arm, i.e. semi-bandit feedback (not bandit feedback, where there would be only one reward for all arms).
    for position in eachindex(arms)
      arm = arms[position]

      state.arm_counts[arm] += 1
      state.arm_reward[arm] += reward[position]
      state.arm_average_reward[arm] = state.arm_reward[arm] / state.arm_counts[arm]

      state.reward += reward[position]
    end

    state.regret += incurred_regret
  end
176

177
  # Use the bandit for the given number of steps.
178
  """
      simulate(instance::CombinatorialInstance{T}, policy::Policy, steps::Int; with_trace::Bool=false) where T

  Runs `policy` on `instance` for `steps` rounds. At each round, the policy chooses arms through
  `choose_action`, the arms are played with `pull`, and the state is updated with `update!`.
  The round number is printed every 100 rounds.

  `choose_action` is called with the keyword argument `with_trace`; when it is `true`, the call is
  expected to return both the arms and the policy details for this round.

  Returns the final state when `with_trace` is `false`, and the tuple `(state, trace)` otherwise.
  The trace records, for each round, the time spent in `choose_action`, in milliseconds.

  Raises an error if the policy chooses no arm at some round.
  """
  function simulate(instance::CombinatorialInstance{T}, policy::Policy, steps::Int; with_trace::Bool=false) where T
    state = initial_state(instance)
    if with_trace
      trace = initial_trace(instance)
    end

    for i in 1:steps
      # Only the choice of the action is timed.
      t0 = time_ns()
      if with_trace
        arms, run_details = choose_action(instance, policy, state, with_trace=true)
      else
        arms = choose_action(instance, policy, state, with_trace=false)
      end
      t1 = time_ns()

      if isempty(arms)
        error("No arms have been chosen at round $(i)!")
      end

      reward, incurred_regret = pull(instance, arms)
      update!(state, instance, arms, reward, incurred_regret)

      if with_trace
        # time_ns() counts nanoseconds, while the trace stores milliseconds.
        elapsed_ms = round(Int, (t1 - t0) / 1_000_000)
        push!(trace, state, arms, reward, run_details, elapsed_ms)
      end

      if i % 100 == 0
        println(i)
      end
    end

    if with_trace
      return state, trace
    else
      return state
    end
  end
215

216
  ## Combinatorial algorithms. TODO: put this in another package.
217
  include("algos/helpers.jl")
218
  include("algos/ep.jl")
219
  include("algos/ep_budgeted.jl")
220
  include("algos/matching.jl")
221
  include("algos/matching_budgeted.jl")
222
  include("algos/msets.jl")
223
  include("algos/msets_budgeted.jl")
224
  include("algos/st.jl")
225
  include("algos/st_budgeted.jl")
226

227
  ## Bandit policies.
228
  include("policies/thompson.jl")
229
  include("policies/llr.jl")
230
  include("policies/cucb.jl")
231
  include("policies/escb2.jl")
232
  include("policies/olsucb.jl")
233

234
  include("policies/escb2_exact.jl")
235
  include("policies/escb2_greedy.jl")
236
  include("policies/escb2_budgeted.jl")
237

238
  include("policies/olsucb_greedy.jl")
239

240
  ## Potential problems to solve.
241
  include("instances/perfectbipartitematching.jl")
242
  include("instances/perfectbipartitematching_algos.jl")
243
  include("instances/perfectbipartitematching_lp.jl")
244
  include("instances/perfectbipartitematching_munkres.jl")
245
  include("instances/perfectbipartitematching_hungarian.jl")
246
  # Not using LightGraphsMatching, due to BlossomV dependency (hard to get to work…)
247

248
  include("instances/elementarypath.jl")
249
  include("instances/elementarypath_algos.jl")
250
  include("instances/elementarypath_lightgraphsdijkstra.jl")
251
  include("instances/elementarypath_lp.jl")
252

253
  include("instances/spanningtree.jl")
254
  include("instances/spanningtree_algos.jl")
255
  include("instances/spanningtree_lightgraphsprim.jl")
256
  include("instances/spanningtree_lp.jl")
257

258
  include("instances/mset.jl")
259
  include("instances/mset_algos.jl")
260
  include("instances/mset_lp.jl")
261
end

Read our documentation on viewing source code .

Loading