Shortcuts

Source code for pypose.module.kalman

import torch
from torch import nn
from ..basics import bmv
from torch.linalg import pinv


class EKF(nn.Module):
    r'''
    Performs Batched Extended Kalman Filter (EKF).

    Args:
        model (:obj:`System`): The system model to be estimated, a subclass of
            :obj:`pypose.module.System`.
        Q (:obj:`Tensor`, optional): The covariance matrices of system transition noise.
            Ignored if provided during each iteration. Default: ``None``
        R (:obj:`Tensor`, optional): The covariance matrices of system observation noise.
            Ignored if provided during each iteration. Default: ``None``

    A non-linear system can be described as

    .. math::
        \begin{aligned}
            \mathbf{x}_{k+1} &= \mathbf{f}(\mathbf{x}_k, \mathbf{u}_k, t_k) + \mathbf{w}_k,
            \quad \mathbf{w}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{Q}) \\
            \mathbf{y}_{k} &= \mathbf{g}(\mathbf{x}_k, \mathbf{u}_k, t_k) + \mathbf{v}_k,
            \quad \mathbf{v}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{R})
        \end{aligned}

    It will be linearized automatically:

    .. math::
        \begin{aligned}
            \mathbf{x}_{k+1} &= \mathbf{A}_{k}\mathbf{x}_{k} + \mathbf{B}_{k}\mathbf{u}_{k}
                        + \mathbf{c}_{k}^1 + \mathbf{w}_k \\
            \mathbf{y}_{k} &= \mathbf{C}_{k}\mathbf{x}_{k} + \mathbf{D}_{k}\mathbf{u}_{k}
                        + \mathbf{c}_{k}^2 + \mathbf{v}_k
        \end{aligned}

    EKF can be described as the following five equations, where the subscript
    :math:`\cdot_{k}` is omitted for simplicity.

    1. Priori State Estimation.

        .. math::
            \mathbf{x}^{-} = \mathbf{A}\mathbf{x} + \mathbf{B}\mathbf{u} + \mathbf{c}_1

    2. Priori Covariance Propagation.

        .. math::
            \mathbf{P}^{-} = \mathbf{A}\mathbf{P}\mathbf{A}^{T} + \mathbf{Q}

    3. Update Kalman Gain

        .. math::
            \mathbf{K} = \mathbf{P}^{-}\mathbf{C}^{T}
                        (\mathbf{C}\mathbf{P}^{-} \mathbf{C}^{T} + \mathbf{R})^{-1}

    4. Posteriori State Estimation

        .. math::
            \mathbf{x}^{+} = \mathbf{x}^{-} + \mathbf{K} (\mathbf{y} -
                        \mathbf{C}\mathbf{x}^{-} - \mathbf{D}\mathbf{u} - \mathbf{c}_2)

    5. Posteriori Covariance Estimation

        .. math::
            \mathbf{P} = (\mathbf{I} - \mathbf{K}\mathbf{C}) \mathbf{P}^{-}

    where superscript :math:`\cdot^{-}` and :math:`\cdot^{+}` denote the priori and
    posteriori estimation, respectively.

    Example:
        1. Define a Nonlinear Time Invariant (NTI) system model

        >>> import torch, pypose as pp
        >>> class NTI(pp.module.System):
        ...     def __init__(self):
        ...         super().__init__()
        ...
        ...     def state_transition(self, state, input, t=None):
        ...         return state.cos() + input
        ...
        ...     def observation(self, state, input, t):
        ...         return state.sin() + input

        2. Create a model and filter

        >>> model = NTI()
        >>> ekf = pp.module.EKF(model)

        3. Prepare data

        >>> T, N = 5, 2 # steps, state dim
        >>> states = torch.zeros(T, N)
        >>> inputs = torch.randn(T, N)
        >>> observ = torch.zeros(T, N)
        >>> # std of transition, observation, and estimation
        >>> q, r, p = 0.1, 0.1, 10
        >>> Q = torch.eye(N) * q**2
        >>> R = torch.eye(N) * r**2
        >>> P = torch.eye(N).repeat(T, 1, 1) * p**2
        >>> estim = torch.randn(T, N) * p

        4. Perform EKF prediction. Note that estimation error becomes smaller with more steps.

        >>> for i in range(T - 1):
        ...     w = q * torch.randn(N) # transition noise
        ...     v = r * torch.randn(N) # observation noise
        ...     states[i+1], observ[i] = model(states[i] + w, inputs[i])
        ...     estim[i+1], P[i+1] = ekf(estim[i], observ[i] + v, inputs[i], P[i], Q, R)
        >>> print('Est error:', (states - estim).norm(dim=-1))
        Est error: tensor([5.7655, 5.3436, 3.5947, 0.3359, 0.0639])

    Warning:
        Don't introduce noise in ``System`` methods ``state_transition`` and ``observation``
        for filter testing, as those methods are used for automatically linearizing the system
        by the parent class ``pypose.module.System``, unless your system model explicitly
        introduces noise.

    Note:
        Implementation is based on Section 5.1 of this book

        * Dan Simon, `Optimal State Estimation: Kalman, H∞, and Nonlinear Approaches
          <https://onlinelibrary.wiley.com/doi/epdf/10.1002/0470045345.fmatter>`_,
          Cleveland State University, 2006
    '''
    def __init__(self, model, Q=None, R=None):
        super().__init__()
        # Covariances are optional here; they can also be supplied per call to forward().
        self.set_uncertainty(Q=Q, R=R)
        self.model = model

    def forward(self, x, y, u, P, Q=None, R=None):
        r'''
        Performs one step estimation.

        The model is linearized by calling ``set_refpoint(state=x, input=u)`` and the
        resulting ``A``, ``B``, ``C``, ``D``, ``c1`` and ``c2`` read from the model are
        used for the prediction and the correction.

        Args:
            x (:obj:`Tensor`): estimated system state of previous step
            y (:obj:`Tensor`): system observation at current step (measurement)
            u (:obj:`Tensor`): system input at current step
            P (:obj:`Tensor`): state estimation covariance of previous step
            Q (:obj:`Tensor`, optional): covariance of system transition model.
                Falls back to the stored ``Q`` when ``None``.
            R (:obj:`Tensor`, optional): covariance of system observation model.
                Falls back to the stored ``R`` when ``None``.

        Return:
            tuple of :obj:`Tensor`: posteriori state and covariance estimation

        Raises:
            NotImplementedError: if ``Q`` (or ``R``) is neither passed here nor
                previously stored via :meth:`set_uncertainty`.
        '''
        # Upper cases are matrices, lower cases are vectors
        self.model.set_refpoint(state=x, input=u)
        I = torch.eye(P.shape[-1], device=P.device, dtype=P.dtype)
        A, B = self.model.A, self.model.B
        C, D = self.model.C, self.model.D
        c1, c2 = self.model.c1, self.model.c2
        # Per-call covariances take precedence over the stored ones.
        Q = Q if Q is not None else self.Q
        R = R if R is not None else self.R

        x = bmv(A, x) + bmv(B, u) + c1        # 1. System transition
        P = A @ P @ A.mT + Q                  # 2. Covariance predict
        K = P @ C.mT @ pinv(C @ P @ C.mT + R) # 3. Kalman gain
        e = y - bmv(C, x) - bmv(D, u) - c2    #    predicted observation error
        x = x + bmv(K, e)                     # 4. Posteriori state
        P = (I - K @ C) @ P                   # 5. Posteriori covariance

        return x, P

    @property
    def Q(self):
        r'''
        The covariance of system transition noise.

        Raises:
            NotImplementedError: if no ``Q`` has been stored via :meth:`set_uncertainty`.
        '''
        if not hasattr(self, '_Q'):
            raise NotImplementedError(
                'Call set_uncertainty() to define transition covariance Q.')
        return self._Q

    @property
    def R(self):
        r'''
        The covariance of system observation noise.

        Raises:
            NotImplementedError: if no ``R`` has been stored via :meth:`set_uncertainty`.
        '''
        if not hasattr(self, '_R'):
            raise NotImplementedError(
                'Call set_uncertainty() to define observation covariance R.')
        return self._R

    def set_uncertainty(self, Q=None, R=None):
        r'''
        Set the covariance matrices of transition noise and observation noise.

        An argument left as ``None`` leaves the corresponding stored covariance untouched.

        Args:
            Q (:obj:`Tensor`): batched square covariance matrices of transition noise.
            R (:obj:`Tensor`): batched square covariance matrices of observation noise.
        '''
        # Stored as buffers under private names; exposed through the Q / R properties.
        if Q is not None:
            self.register_buffer("_Q", Q)
        if R is not None:
            self.register_buffer("_R", R)

Docs

Access documentation for PyPose

View Docs

Tutorials

Get started with tutorials and examples

View Tutorials

Get Started

Find resources and how to start using pypose

View Resources