JuliaPOMDP / ARDESPOT.jl
Showing 3 of 5 files from the diff.

@@ -40,7 +40,6 @@
Loading
40 40
41 41
    ReportWhenUsed
42 42
43 -
44 43
# include("random.jl")
45 44
include("random_2.jl")
46 45

@@ -1,10 +1,10 @@
Loading
1 -
function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::Integer)
1 +
function branching_sim(pomdp::POMDP, policy::Policy, b::ScenarioBelief, steps::Integer, fval)
2 2
    S = statetype(pomdp)
3 3
    O = obstype(pomdp)
4 4
    odict = Dict{O, Vector{Pair{Int, S}}}()
5 5
6 6
    if steps <= 0
7 -
        return 0.0
7 +
        return length(b.scenarios)*fval(pomdp, b)
8 8
    end
9 9
10 10
    a = action(policy, b)
@@ -29,9 +29,9 @@
Loading
29 29
    for (o, scenarios) in odict 
30 30
        bp = ScenarioBelief(scenarios, b.random_source, b.depth+1, o)
31 31
        if length(scenarios) == 1
32 -
            next_r += rollout(pomdp, policy, bp, steps-1)
32 +
            next_r += rollout(pomdp, policy, bp, steps-1, fval)
33 33
        else
34 -
            next_r += branching_sim(pomdp, policy, bp, steps-1)
34 +
            next_r += branching_sim(pomdp, policy, bp, steps-1, fval)
35 35
        end
36 36
    end
37 37
@@ -39,7 +39,7 @@
Loading
39 39
end
40 40
41 41
# once there is only one scenario left, just run a rollout
42 -
function rollout(pomdp::POMDP, policy::Policy, b0::ScenarioBelief, steps::Integer)
42 +
function rollout(pomdp::POMDP, policy::Policy, b0::ScenarioBelief, steps::Integer, fval)
43 43
    @assert length(b0.scenarios) == 1
44 44
    disc = 1.0
45 45
    r_total = 0.0
@@ -63,5 +63,9 @@
Loading
63 63
        steps -= 1
64 64
    end
65 65
66 +
    if steps == 0 && !isterminal(pomdp, s)
67 +
        r_total += disc*fval(pomdp, b)
68 +
    end
69 +
66 70
    return r_total
67 71
end

@@ -75,26 +75,39 @@
Loading
75 75
76 76
# Default Policy Lower Bound
77 77
78 -
struct DefaultPolicyLB{P<:Union{Solver, Policy}, D<:Union{Nothing,Int}}
78 +
"""
79 +
    DefaultPolicyLB(policy; max_depth=nothing, final_value=(m,x)->0.0)
80 +
    DefaultPolicyLB(solver; max_depth=nothing, final_value=(m,x)->0.0)
81 +
82 +
A lower bound calculated by running a default policy on the scenarios in a belief.
83 +
84 +
# Keyword Arguments
85 +
- `max_depth::Union{Nothing,Int}=nothing`: max depth to run the simulation. The depth of the belief will be automatically subtracted so simulations for the bound will be run for `max_depth-b.depth` steps. If `nothing`, the solver's max depth will be used.
86 +
- `final_value=(m,x)->0.0`: a function (or callable object) that specifies an additional value to be added at the end of the simulation when `max_depth` is reached. This function will be called with two arguments, a `POMDP`, and a `ScenarioBelief`. It will not be called when the states in the belief are terminal.
87 +
"""
88 +
struct DefaultPolicyLB{P<:Union{Solver, Policy}, D<:Union{Nothing,Int}, T}
79 89
    policy::P
80 90
    max_depth::D
91 +
    final_value::T
81 92
end
82 93
83 -
function DefaultPolicyLB(policy_or_solver::T; max_depth=nothing) where T <: Union{Solver, Policy}
84 -
    return DefaultPolicyLB(policy_or_solver, max_depth)
94 +
function DefaultPolicyLB(policy_or_solver::T;
95 +
                         max_depth=nothing,
96 +
                         final_value=(m,x)->0.0) where T <: Union{Solver, Policy}
97 +
    return DefaultPolicyLB(policy_or_solver, max_depth, final_value)
85 98
end
86 99
87 100
function lbound(lb::DefaultPolicyLB, pomdp::POMDP, b::ScenarioBelief)
88 -
    rsum = branching_sim(pomdp, lb.policy, b, lb.max_depth-b.depth)
101 +
    rsum = branching_sim(pomdp, lb.policy, b, lb.max_depth-b.depth, lb.final_value)
89 102
    return rsum/length(b.scenarios)
90 103
end
91 104
92 105
function init_bound(lb::DefaultPolicyLB{S}, pomdp::POMDP, sol::DESPOTSolver) where S <: Solver
93 106
    policy = solve(lb.policy, pomdp)
94 -
    return init_bound(DefaultPolicyLB(policy, lb.max_depth), pomdp, sol)
107 +
    return init_bound(DefaultPolicyLB(policy, lb.max_depth, lb.final_value), pomdp, sol)
95 108
end
96 109
97 110
function init_bound(lb::DefaultPolicyLB{P}, pomdp::POMDP, sol::DESPOTSolver) where P <: Policy
98 111
    max_depth = something(lb.max_depth, sol.D)
99 -
    return DefaultPolicyLB(lb.policy, max_depth)
112 +
    return DefaultPolicyLB(lb.policy, max_depth, lb.final_value)
100 113
end
Files Coverage
src 47.43%
Project Totals (12 files) 47.43%
Sunburst
The innermost circle is the entire project; moving away from the center are folders and then, finally, a single file. The size and color of each slice represent the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project, followed by folders and finally individual files. The size and color of each slice represent the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block represent the number of statements and the coverage, respectively.
Loading