35 |
60 |
|
|
36 |
61 |
|
See Also |
37 |
62 |
|
-------- |
38 |
|
- |
problinsolve : Solve linear systems in a Bayesian framework. |
39 |
|
- |
bayescg : Solve linear systems with prior information on the solution. |
|
63 |
+ |
~probnum.linalg.problinsolve : Solve linear systems in a Bayesian framework. |
|
64 |
+ |
~probnum.linalg.bayescg : Solve linear systems with prior information on the solution. |
40 |
65 |
|
|
41 |
66 |
|
Examples |
42 |
67 |
|
-------- |
|
68 |
+ |
Define a linear system. |
|
69 |
+ |
|
|
70 |
+ |
>>> import numpy as np |
|
71 |
+ |
>>> from probnum.problems import LinearSystem |
|
72 |
+ |
>>> from probnum.problems.zoo.linalg import random_spd_matrix |
|
73 |
+ |
|
|
74 |
+ |
>>> rng = np.random.default_rng(42) |
|
75 |
+ |
>>> n = 100 |
|
76 |
+ |
>>> A = random_spd_matrix(rng=rng, dim=n) |
|
77 |
+ |
>>> b = rng.standard_normal(size=(n,)) |
|
78 |
+ |
>>> linsys = LinearSystem(A=A, b=b) |
|
79 |
+ |
|
|
80 |
+ |
Create a custom probabilistic linear solver from pre-defined components. |
|
81 |
+ |
|
|
82 |
+ |
>>> from probnum.linalg.solvers import ( |
|
83 |
+ |
... ProbabilisticLinearSolver, |
|
84 |
+ |
... belief_updates, |
|
85 |
+ |
... beliefs, |
|
86 |
+ |
... information_ops, |
|
87 |
+ |
... policies, |
|
88 |
+ |
... stopping_criteria, |
|
89 |
+ |
... ) |
|
90 |
+ |
|
|
91 |
+ |
>>> pls = ProbabilisticLinearSolver( |
|
92 |
+ |
... policy=policies.ConjugateGradientPolicy(), |
|
93 |
+ |
... information_op=information_ops.ProjectedRHSInformationOp(), |
|
94 |
+ |
... belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), |
|
95 |
+ |
... stopping_criterion=( |
|
96 |
+ |
... stopping_criteria.MaxIterationsStoppingCriterion(100) |
|
97 |
+ |
... | stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5) |
|
98 |
+ |
... ), |
|
99 |
+ |
... ) |
|
100 |
+ |
|
|
101 |
+ |
Define a prior over the solution. |
|
102 |
+ |
|
|
103 |
+ |
>>> from probnum import linops, randvars |
|
104 |
+ |
>>> prior = beliefs.LinearSystemBelief( |
|
105 |
+ |
... x=randvars.Normal( |
|
106 |
+ |
... mean=np.zeros((n,)), |
|
107 |
+ |
... cov=np.eye(n), |
|
108 |
+ |
... ), |
|
109 |
+ |
... ) |
|
110 |
+ |
|
|
111 |
+ |
Solve the linear system using the custom solver. |
|
112 |
+ |
|
|
113 |
+ |
>>> belief, solver_state = pls.solve(prior=prior, problem=linsys) |
|
114 |
+ |
>>> np.linalg.norm(linsys.A @ belief.x.mean - linsys.b) / np.linalg.norm(linsys.b) |
|
115 |
+ |
7.1886e-06 |
|
116 |
+ |
""" |
|
117 |
+ |
|
|
118 |
+ |
def __init__( |
|
119 |
+ |
self, |
|
120 |
+ |
policy: policies.LinearSolverPolicy, |
|
121 |
+ |
information_op: information_ops.LinearSolverInformationOp, |
|
122 |
+ |
belief_update: belief_updates.LinearSolverBeliefUpdate, |
|
123 |
+ |
stopping_criterion: stopping_criteria.LinearSolverStoppingCriterion, |
|
124 |
+ |
): |
|
125 |
+ |
self.policy = policy |
|
126 |
+ |
self.information_op = information_op |
|
127 |
+ |
self.belief_update = belief_update |
|
128 |
+ |
super().__init__(stopping_criterion=stopping_criterion) |
|
129 |
+ |
|
|
130 |
+ |
def solve_iterator( |
|
131 |
+ |
self, |
|
132 |
+ |
prior: beliefs.LinearSystemBelief, |
|
133 |
+ |
problem: problems.LinearSystem, |
|
134 |
+ |
rng: Optional[np.random.Generator] = None, |
|
135 |
+ |
) -> Generator[LinearSolverState, None, None]: |
|
136 |
+ |
"""Generator implementing the solver iteration. |
|
137 |
+ |
|
|
138 |
+ |
This function allows stepping through the solver iteration one step at a time |
|
139 |
+ |
and exposes the internal solver state. |
|
140 |
+ |
|
|
141 |
+ |
Parameters |
|
142 |
+ |
---------- |
|
143 |
+ |
prior |
|
144 |
+ |
Prior belief about the quantities of interest :math:`(x, A, A^{-1}, b)` of the linear system. |
|
145 |
+ |
problem |
|
146 |
+ |
Linear system to be solved. |
|
147 |
+ |
rng |
|
148 |
+ |
Random number generator. |
|
149 |
+ |
|
|
150 |
+ |
Yields |
|
151 |
+ |
------ |
|
152 |
+ |
solver_state |
|
153 |
+ |
State of the probabilistic linear solver. |
|
154 |
+ |
""" |
|
155 |
+ |
solver_state = LinearSolverState(problem=problem, prior=prior, rng=rng) |
|
156 |
+ |
|
|
157 |
+ |
while True: |
|
158 |
+ |
|
|
159 |
+ |
yield solver_state |
|
160 |
+ |
|
|
161 |
+ |
# Check stopping criterion |
|
162 |
+ |
if self.stopping_criterion(solver_state=solver_state): |
|
163 |
+ |
break |
|
164 |
+ |
|
|
165 |
+ |
# Compute action via policy |
|
166 |
+ |
solver_state.action = self.policy(solver_state=solver_state) |
|
167 |
+ |
|
|
168 |
+ |
# Make observation via information operator |
|
169 |
+ |
solver_state.observation = self.information_op(solver_state=solver_state) |
|
170 |
+ |
|
|
171 |
+ |
# Update belief about the quantities of interest |
|
172 |
+ |
solver_state.belief = self.belief_update(solver_state=solver_state) |
|
173 |
+ |
|
|
174 |
+ |
# Advance state to next step and invalidate caches |
|
175 |
+ |
solver_state.next_step() |
|
176 |
+ |
|
|
177 |
+ |
def solve( |
|
178 |
+ |
self, |
|
179 |
+ |
prior: beliefs.LinearSystemBelief, |
|
180 |
+ |
problem: problems.LinearSystem, |
|
181 |
+ |
rng: Optional[np.random.Generator] = None, |
|
182 |
+ |
) -> Tuple[beliefs.LinearSystemBelief, LinearSolverState]: |
|
183 |
+ |
r"""Solve the linear system. |
|
184 |
+ |
|
|
185 |
+ |
Parameters |
|
186 |
+ |
---------- |
|
187 |
+ |
prior |
|
188 |
+ |
Prior belief about the quantities of interest :math:`(x, A, A^{-1}, b)` of the linear system. |
|
189 |
+ |
problem |
|
190 |
+ |
Linear system to be solved. |
|
191 |
+ |
rng |
|
192 |
+ |
Random number generator. |
|
193 |
+ |
|
|
194 |
+ |
Returns |
|
195 |
+ |
------- |
|
196 |
+ |
belief |
|
197 |
+ |
Posterior belief :math:`(\mathsf{x}, \mathsf{A}, \mathsf{H}, \mathsf{b})` |
|
198 |
+ |
over the solution :math:`x`, the system matrix :math:`A`, its (pseudo-)inverse :math:`H=A^\dagger` and the right hand side :math:`b`. |
|
199 |
+ |
solver_state |
|
200 |
+ |
Final state of the solver. |
|
201 |
+ |
""" |
|
202 |
+ |
solver_state = None |
|
203 |
+ |
|
|
204 |
+ |
for solver_state in self.solve_iterator(prior=prior, problem=problem, rng=rng): |
|
205 |
+ |
pass |
|
206 |
+ |
|
|
207 |
+ |
return solver_state.belief, solver_state |
|
208 |
+ |
|
|
209 |
+ |
|
|
210 |
+ |
class BayesCG(ProbabilisticLinearSolver): |
|
211 |
+ |
r"""Bayesian conjugate gradient method. |
|
212 |
+ |
|
|
213 |
+ |
Probabilistic linear solver taking prior information about the solution and |
|
214 |
+ |
choosing :math:`A`-conjugate actions to gain information about the solution |
|
215 |
+ |
by projecting the current residual. |
|
216 |
+ |
|
|
217 |
+ |
This code implements the method described in Cockayne et al. [1]_. |
|
218 |
+ |
|
|
219 |
+ |
Parameters |
|
220 |
+ |
---------- |
|
221 |
+ |
stopping_criterion |
|
222 |
+ |
Stopping criterion determining when a desired terminal condition is met. |
|
223 |
+ |
|
|
224 |
+ |
References |
|
225 |
+ |
---------- |
|
226 |
+ |
.. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian |
|
227 |
+ |
Analysis*, 2019 |
|
228 |
+ |
""" |
|
229 |
+ |
|
|
230 |
+ |
def __init__( |
|
231 |
+ |
self, |
|
232 |
+ |
stopping_criterion: stopping_criteria.LinearSolverStoppingCriterion = stopping_criteria.MaxIterationsStoppingCriterion() |
|
233 |
+ |
| stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5), |
|
234 |
+ |
): |
|
235 |
+ |
super().__init__( |
|
236 |
+ |
policy=policies.ConjugateGradientPolicy(), |
|
237 |
+ |
information_op=information_ops.ProjectedRHSInformationOp(), |
|
238 |
+ |
belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), |
|
239 |
+ |
stopping_criterion=stopping_criterion, |
|
240 |
+ |
) |
|
241 |
+ |
|
|
242 |
+ |
|
|
243 |
+ |
class ProbabilisticKaczmarz(ProbabilisticLinearSolver): |
|
244 |
+ |
r"""Probabilistic Kaczmarz method. |
|
245 |
+ |
|
|
246 |
+ |
Probabilistic analogue of the (randomized) Kaczmarz method [1]_ [2]_, taking prior |
|
247 |
+ |
information about the solution and randomly choosing rows of the matrix :math:`A_i` |
|
248 |
+ |
and entries :math:`b_i` of the right-hand-side to obtain information about the solution. |
|
249 |
+ |
|
|
250 |
+ |
Parameters |
|
251 |
+ |
---------- |
|
252 |
+ |
stopping_criterion |
|
253 |
+ |
Stopping criterion determining when a desired terminal condition is met. |
|
254 |
+ |
|
|
255 |
+ |
References |
|
256 |
+ |
---------- |
|
257 |
+ |
.. [1] Kaczmarz, Stefan, Angenäherte Auflösung von Systemen linearer Gleichungen, |
|
258 |
+ |
*Bulletin International de l'Académie Polonaise des Sciences et des Lettres. Classe des Sciences Mathématiques et Naturelles. Série A, Sciences Mathématiques*, 1937 |
|
259 |
+ |
.. [2] Strohmer, Thomas; Vershynin, Roman, A randomized Kaczmarz algorithm for |
|
260 |
+ |
linear systems with exponential convergence, *Journal of Fourier Analysis and Applications*, 2009 |
|
261 |
+ |
""" |
|
262 |
+ |
|
|
263 |
+ |
def __init__( |
|
264 |
+ |
self, |
|
265 |
+ |
stopping_criterion: stopping_criteria.LinearSolverStoppingCriterion = stopping_criteria.MaxIterationsStoppingCriterion() |
|
266 |
+ |
| stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5), |
|
267 |
+ |
): |
|
268 |
+ |
super().__init__( |
|
269 |
+ |
policy=policies.RandomUnitVectorPolicy(), |
|
270 |
+ |
information_op=information_ops.ProjectedRHSInformationOp(), |
|
271 |
+ |
belief_update=belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate(), |
|
272 |
+ |
stopping_criterion=stopping_criterion, |
|
273 |
+ |
) |
|
274 |
+ |
|
|
275 |
+ |
|
|
276 |
+ |
class MatrixBasedPLS(ProbabilisticLinearSolver): |
|
277 |
+ |
r"""Matrix-based probabilistic linear solver. |
|
278 |
+ |
|
|
279 |
+ |
Probabilistic linear solver updating beliefs over the system matrix and its |
|
280 |
+ |
inverse. The solver makes use of prior information and iteratively infers the matrix and its inverse by matrix-vector multiplication. |
|
281 |
+ |
|
|
282 |
+ |
This code implements the method described in Wenger et al. [1]_. |
|
283 |
+ |
|
|
284 |
+ |
Parameters |
|
285 |
+ |
---------- |
|
286 |
+ |
policy |
|
287 |
+ |
Policy returning actions taken by the solver. |
|
288 |
+ |
stopping_criterion |
|
289 |
+ |
Stopping criterion determining when a desired terminal condition is met. |
|
290 |
+ |
|
|
291 |
+ |
References |
|
292 |
+ |
---------- |
|
293 |
+ |
.. [1] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, |
|
294 |
+ |
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020 |
|
295 |
+ |
""" |
|
296 |
+ |
|
|
297 |
+ |
def __init__( |
|
298 |
+ |
self, |
|
299 |
+ |
policy: policies.LinearSolverPolicy = policies.ConjugateGradientPolicy(), |
|
300 |
+ |
stopping_criterion: stopping_criteria.LinearSolverStoppingCriterion = stopping_criteria.MaxIterationsStoppingCriterion() |
|
301 |
+ |
| stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5), |
|
302 |
+ |
): |
|
303 |
+ |
super().__init__( |
|
304 |
+ |
policy=policy, |
|
305 |
+ |
information_op=information_ops.MatVecInformationOp(), |
|
306 |
+ |
belief_update=belief_updates.matrix_based.MatrixBasedLinearBeliefUpdate(), |
|
307 |
+ |
stopping_criterion=stopping_criterion, |
|
308 |
+ |
) |
|
309 |
+ |
|
|
310 |
+ |
|
|
311 |
+ |
class SymMatrixBasedPLS(ProbabilisticLinearSolver): |
|
312 |
+ |
r"""Symmetric matrix-based probabilistic linear solver. |
|
313 |
+ |
|
|
314 |
+ |
Probabilistic linear solver updating beliefs over the symmetric system matrix and its inverse. The solver makes use of prior information and iteratively infers the matrix and its inverse by matrix-vector multiplication. |
|
315 |
+ |
|
|
316 |
+ |
This code implements the method described in Wenger et al. [1]_. |
|
317 |
+ |
|
|
318 |
+ |
Parameters |
|
319 |
+ |
---------- |
|
320 |
+ |
policy |
|
321 |
+ |
Policy returning actions taken by the solver. |
|
322 |
+ |
stopping_criterion |
|
323 |
+ |
Stopping criterion determining when a desired terminal condition is met. |
|
324 |
+ |
|
|
325 |
+ |
References |
|
326 |
+ |
---------- |
|
327 |
+ |
.. [1] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, |
|
328 |
+ |
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020 |
43 |
329 |
|
""" |
|
330 |
+ |
|
|
331 |
+ |
def __init__( |
|
332 |
+ |
self, |
|
333 |
+ |
policy: policies.LinearSolverPolicy = policies.ConjugateGradientPolicy(), |
|
334 |
+ |
stopping_criterion: stopping_criteria.LinearSolverStoppingCriterion = stopping_criteria.MaxIterationsStoppingCriterion() |
|
335 |
+ |
| stopping_criteria.ResidualNormStoppingCriterion(atol=1e-5, rtol=1e-5), |
|
336 |
+ |
): |
|
337 |
+ |
super().__init__( |
|
338 |
+ |
policy=policy, |
|
339 |
+ |
information_op=information_ops.MatVecInformationOp(), |
|
340 |
+ |
belief_update=belief_updates.matrix_based.SymmetricMatrixBasedLinearBeliefUpdate(), |
|
341 |
+ |
stopping_criterion=stopping_criterion, |
|
342 |
+ |
) |