Shortcuts

Source code for pypose.lietensor.lietensor

import collections
import math, numbers
import torch, warnings
from torch import nn, linalg
from .operation import broadcast_inputs
from .basics import cumops_, cummul_, cumprod_
from .basics import vec2skew, cumops, cummul, cumprod
from torch.utils._pytree import tree_map, tree_flatten
from .operation import SO3_Log, SE3_Log, RxSO3_Log, Sim3_Log
from .operation import so3_Exp, se3_Exp, rxso3_Exp, sim3_Exp
from .operation import SO3_Act, SE3_Act, RxSO3_Act, Sim3_Act
from .operation import SO3_Mul, SE3_Mul, RxSO3_Mul, Sim3_Mul
from .operation import SO3_Inv, SE3_Inv, RxSO3_Inv, Sim3_Inv
from .operation import SO3_Act4, SE3_Act4, RxSO3_Act4, Sim3_Act4
from .operation import SO3_AdjXa, SE3_AdjXa, RxSO3_AdjXa, Sim3_AdjXa
from .operation import SO3_AdjTXa, SE3_AdjTXa, RxSO3_AdjTXa, Sim3_AdjTXa
from .operation import so3_Jl_inv, se3_Jl_inv, rxso3_Jl_inv, sim3_Jl_inv
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple, _ntuple


# Names (``func.__name__``) of torch functions / Tensor methods whose tensor
# results LieTensor.__torch_function__ re-wraps as LieTensor, carrying the
# ltype of the first LieTensor argument. Only membership is tested, so each
# name is listed exactly once.
HANDLED_FUNCTIONS = ['__getitem__', '__setitem__', 'cpu', 'cuda', 'float', 'double',
                     'to', 'detach', 'view', 'view_as', 'squeeze', 'unsqueeze', 'cat',
                     'stack', 'split', 'hsplit', 'dsplit', 'vsplit', 'tensor_split',
                     'chunk', 'concat', 'column_stack', 'dstack', 'vstack', 'hstack',
                     'index_select', 'masked_select', 'movedim', 'moveaxis', 'narrow',
                     'permute', 'reshape', 'row_stack', 'scatter', 'scatter_add', 'clone',
                     'swapaxes', 'swapdims', 'take', 'take_along_dim', 'tile', 'copy',
                     'transpose', 'unbind', 'gather', 'repeat', 'expand', 'expand_as',
                     'index_copy', 'index_copy_',
                     'select', 'select_scatter', 'index_put', 'index_put_', 'copy_']

class LieType:
    '''LieTensor Type Base Class.

    Stores three sizes, each as a one-element ``torch.Size``.

    Args:
        dimension (int): size of the last dimension of the LieTensor data.
        embedding (int): embedding dimension.
        manifold (int): manifold dimension.

    A type whose ``dimension`` equals its ``manifold`` dimension reports
    ``on_manifold == True`` (in this module: the Lie algebra types). The
    methods below are defaults that either raise or implement an operation
    generically; concrete types override them.
    '''
    def __init__(self, dimension, embedding, manifold):
        self._dimension = torch.Size([dimension]) # Data dimension
        self._embedding = torch.Size([embedding]) # Embedding dimension
        self._manifold  = torch.Size([manifold])  # Manifold dimension

    @property
    def dimension(self):
        '''Data dimension as ``torch.Size([dimension])``.'''
        return self._dimension

    @property
    def embedding(self):
        '''Embedding dimension as ``torch.Size([embedding])``.'''
        return self._embedding

    @property
    def manifold(self):
        '''Manifold dimension as ``torch.Size([manifold])``.'''
        return self._manifold

    @property
    def on_manifold(self):
        '''True when the data dimension equals the manifold dimension.'''
        return self.dimension == self.manifold

    def add_(self, input, other):
        '''In-place ``input += other[..., :manifold]`` (only for on-manifold types).

        Both operands are viewed as plain tensors for the addition so that
        no LieTensor dispatch happens; ``other`` may be wider than ``input``.

        Raises:
            NotImplementedError: if the type is not on the manifold.
        '''
        if self.on_manifold:
            other1 = torch.Tensor.as_subclass(input, torch.Tensor)
            other2 = torch.Tensor.as_subclass(other, torch.Tensor)
            return input.copy_(other1 + other2[..., :self.manifold[0]])
        raise NotImplementedError("Instance has no add_ attribute.")

    def Log(self, X):
        '''Logarithm map; on-manifold types raise AttributeError.'''
        if self.on_manifold:
            raise AttributeError("Lie Algebra has no Log attribute")
        raise NotImplementedError("Instance has no Log attribute.")

    def Exp(self, x):
        '''Exponential map; types not on the manifold raise AttributeError.'''
        if not self.on_manifold:
            raise AttributeError("Lie Group has no Exp attribute")
        raise NotImplementedError("Instance has no Exp attribute.")

    def Inv(self, x):
        '''Inverse; for on-manifold types this is the negation ``-x``.'''
        if self.on_manifold:
            return - x
        raise NotImplementedError("Instance has no Inv attribute.")

    def Act(self, X, p):
        """ action on a points tensor(*, 3[4]) (homogeneous)"""
        if not self.on_manifold:
            raise AttributeError("Lie Group has no Act attribute")
        raise NotImplementedError("Instance has no Act attribute.")

    def Mul(self, X, Y):
        '''Multiplication; types not on the manifold raise AttributeError here.'''
        if not self.on_manifold:
            raise AttributeError("Lie Group has no Mul attribute")
        raise NotImplementedError("Instance has no Mul attribute.")

    def Retr(self, X, a):
        '''Retraction ``Exp(a) * X``; unavailable for on-manifold types.'''
        if self.on_manifold:
            raise AttributeError("Has no Retr attribute")
        return a.Exp() * X

    def Adj(self, X, a):
        ''' X * Exp(a) = Exp(Adj) * X '''
        if not self.on_manifold:
            raise AttributeError("Lie Group has no Adj attribute")
        raise NotImplementedError("Instance has no Adj attribute.")

    def AdjT(self, X, a):
        ''' Exp(a) * X = X * Exp(AdjT) '''
        if not self.on_manifold:
            raise AttributeError("Lie Group has no AdjT attribute")
        raise NotImplementedError("Instance has no AdjT attribute.")

    def Jinvp(self, X, p):
        '''Inverse-Jacobian product; must be provided by concrete types.'''
        if not self.on_manifold:
            raise AttributeError("Lie Group has no Jinvp attribute")
        raise NotImplementedError("Instance has no Jinvp attribute.")

    def matrix(self, input):
        """ To 4x4 matrix """
        # On-manifold inputs are first mapped to the group; the matrix is
        # obtained by acting on the 4x4 identity's homogeneous basis vectors.
        X = input.Exp() if self.on_manifold else input
        I = torch.eye(4, dtype=X.dtype, device=X.device)
        I = I.view([1] * (X.dim() - 1) + [4, 4])
        return X.unsqueeze(-2).Act(I).transpose(-1,-2)

    def rotation(self, input):
        '''Rotation part; not available on the base type.'''
        raise NotImplementedError("Rotation is not implemented for the instance.")

    def translation(self, input):
        '''Warn and return zeros of shape ``input.lshape + (3,)``.'''
        warnings.warn("Instance has no translation. Zero vector(s) is returned.")
        return torch.zeros(input.lshape + (3,), dtype=input.dtype, device=input.device,
            requires_grad=input.requires_grad)

    def scale(self, input):
        '''Warn and return ones of shape ``input.lshape + (1,)``.'''
        warnings.warn("Instance has no scale. Scalar one(s) is returned.")
        return torch.ones(input.lshape + (1,), dtype=input.dtype, device=input.device,
            requires_grad=input.requires_grad)

    @classmethod
    def to_tuple(cls, input):
        '''Flatten one level: ints are kept, iterables are expanded in place.

        E.g. ``(2, (3, 4))`` -> ``(2, 3, 4)``.
        '''
        out = tuple()
        for i in input:
            if not isinstance(i, collections.abc.Iterable):
                out += (i,)
            else:
                out += tuple(i)
        return out

    @classmethod
    def identity(cls, *args, **kwargs):
        '''Identity element(s); must be provided by concrete types.'''
        raise NotImplementedError("Instance has no identity.")

    @classmethod
    def identity_like(cls, *args, **kwargs):
        '''Forward all arguments to :meth:`identity`.'''
        return cls.identity(*args, **kwargs)

    def randn_like(self, *args, sigma=1.0, **kwargs):
        '''Forward all arguments to :meth:`randn`.'''
        return self.randn(*args, sigma=sigma, **kwargs)

    def randn(self, *args, **kwargs):
        '''Random elements; must be provided by concrete types.'''
        raise NotImplementedError("randn not implemented yet")

    @classmethod
    def cumops(cls, X, dim, ops):
        '''Delegate to the module-level ``cumops`` helper.'''
        return cumops(X, dim, ops)

    @classmethod
    def cummul(cls, X, dim):
        '''Delegate to the module-level ``cummul`` helper.'''
        return cummul(X, dim)

    @classmethod
    def cumprod(cls, X, dim, left = True):
        '''Delegate to the module-level ``cumprod`` helper.'''
        return cumprod(X, dim, left)

    @classmethod
    def cumops_(cls, X, dim, ops):
        '''Delegate to the module-level ``cumops_`` helper.'''
        return cumops_(X, dim, ops)

    @classmethod
    def cummul_(cls, X, dim):
        '''Delegate to the module-level ``cummul_`` helper.'''
        return cummul_(X, dim)

    @classmethod
    def cumprod_(cls, X, dim):
        '''Delegate to the module-level ``cumprod_`` helper.'''
        return cumprod_(X, dim)


class SO3Type(LieType):
    '''LieType of the rotation group SO(3).

    LieTensors of this type have a last dimension of 4 and the identity
    element ``[0, 0, 0, 1]`` (see :meth:`identity`). The associated Lie
    algebra type is ``so3_type``.
    '''
    def __init__(self):
        super().__init__(4, 4, 3)

    def Log(self, X):
        '''Logarithm map: SO3 data ``X`` -> ``so3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        x = SO3_Log.apply(X)
        return LieTensor(x, ltype=so3_type)

    def Act(self, X, p):
        '''Act with rotations ``X`` on points ``p`` of shape ``(*, 3)`` or ``(*, 4)``.

        Leading dimensions go through ``broadcast_inputs``; the result is
        viewed as its ``out_shape`` plus the trailing point dimension.

        Raises:
            AssertionError: if ``p`` is not a tensor or ``p.shape[-1]`` is
                neither 3 nor 4.
        '''
        assert not self.on_manifold and isinstance(p, torch.Tensor)
        assert p.shape[-1]==3 or p.shape[-1]==4, "Invalid Tensor Dimension"
        X = X.tensor() if hasattr(X, 'ltype') else X
        inputs, out_shape = broadcast_inputs(X, p)
        if p.shape[-1]==3:
            out = SO3_Act.apply(*inputs)
        else:
            out = SO3_Act4.apply(*inputs)
        # view(..., -1) cannot infer a size for an empty tensor, so the
        # trailing size is spelled out; transformed points keep the size of
        # the input points (3 or 4), not the size of X.
        dim = -1 if out.nelement() != 0 else p.shape[-1]
        return out.view(out_shape + (dim,))

    def Mul(self, X, Y):
        '''Compose with another Lie-group LieTensor, or act on a point tensor.

        Raises:
            NotImplementedError: if ``Y`` is neither.
        '''
        # Transform on transform
        X = X.tensor() if hasattr(X, 'ltype') else X
        if not self.on_manifold and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            Y = Y.tensor() if hasattr(Y, 'ltype') else Y
            inputs, out_shape = broadcast_inputs(X, Y)
            out = SO3_Mul.apply(*inputs)
            # Composition keeps the group's own trailing size.
            dim = -1 if out.nelement() != 0 else X.shape[-1]
            out = out.view(out_shape + (dim,))
            return LieTensor(out, ltype=SO3_type)
        # Transform on points
        if not self.on_manifold and isinstance(Y, torch.Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=SO3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Group inverse, returned as an ``SO3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        out = SO3_Inv.apply(X)
        return LieTensor(out, ltype=SO3_type)

    def Adj(self, X, a):
        '''Adjoint ``b`` with ``X * Exp(a) = Exp(b) * X``, as ``so3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = SO3_AdjXa.apply(*inputs)
        # Empty result: trailing size is the algebra (manifold) size, which
        # is what the so3_type LieTensor below requires.
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=so3_type)

    def AdjT(self, X, a):
        '''Adjoint transpose ``b`` with ``Exp(a) * X = X * Exp(b)``, as ``so3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = SO3_AdjTXa.apply(*inputs)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=so3_type)

    def Jinvp(self, X, a):
        '''Compute ``so3_Jl_inv(Log(X)) @ a`` (``so3_type``).'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        (X, a), out_shape = broadcast_inputs(X, a)
        out = (so3_Jl_inv(SO3_Log.apply(X)) @ a.unsqueeze(-1)).squeeze(-1)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=so3_type)

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Identity rotation(s) with lshape ``size``; ``kwargs`` go to ``torch.tensor``.

        ``size`` may be given as ints and/or iterables of ints (flattened by
        ``to_tuple``, as ``randn`` does).
        '''
        size = cls.to_tuple(size)
        data = torch.tensor([0., 0., 0., 1.], **kwargs)
        return LieTensor(data.repeat(size+(1,)), ltype=SO3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random rotations: ``Exp`` of detached ``so3_type.randn`` samples.'''
        data = so3_type.Exp(so3_type.randn(*size, sigma=sigma, **kwargs)).detach()
        return LieTensor(data, ltype=SO3_type).requires_grad_(requires_grad)

    @classmethod
    def add_(cls, input, other):
        '''In place: ``input <- Exp(other[..., :3]) * input``.'''
        return input.copy_(LieTensor(other[..., :3], ltype=so3_type).Exp() * input)

    def matrix(self, input):
        """ To 3x3 matrix """
        I = torch.eye(3, dtype=input.dtype, device=input.device)
        I = I.view([1] * (input.dim() - 1) + [3, 3])
        return input.unsqueeze(-2).Act(I).transpose(-1,-2)

    def rotation(self, input):
        '''An SO3 element is its own rotation.'''
        return input

    def identity_(self, X):
        '''Overwrite ``X`` in place with zeros and a 1 in the last slot.'''
        X.fill_(0)
        X.index_fill_(dim=-1, index=torch.tensor([-1], device=X.device), value=1)
        return X

    def Jr(self, X):
        """
        Right jacobian of SO(3)
        """
        return X.Log().Jr()


class so3Type(LieType):
    '''LieType of the Lie algebra so(3): LieTensors with a last dimension of 3.

    The associated group type is ``SO3_type``.
    '''
    def __init__(self):
        super().__init__(3, 4, 3)

    def Exp(self, x):
        '''Exponential map: so3 data ``x`` -> ``SO3_type`` LieTensor.'''
        x = x.tensor() if hasattr(x, 'ltype') else x
        X = so3_Exp.apply(x)
        return LieTensor(X, ltype=SO3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul`` with a scalar or tensor, kept as ``so3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=so3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    @classmethod
    def identity(cls, *size, **kwargs):
        '''``Log`` of the SO3 identity with the same arguments.'''
        return SO3_type.Log(SO3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random so3 vectors: uniform random direction times ``sigma * N(0, 1)``.

        Raises:
            AssertionError: if ``sigma`` is not a single number.
        '''
        assert isinstance(sigma, numbers.Number), 'Only accepts sigma as a single number'
        size = self.to_tuple(size)
        data = torch.randn(*(size + torch.Size([3])), **kwargs)
        dist = data.norm(dim=-1, keepdim=True)
        theta = sigma * torch.randn(*(size + torch.Size([1])), **kwargs)
        return LieTensor(data / dist * theta, ltype=so3_type).requires_grad_(requires_grad)

    def matrix(self, input):
        """ To 3x3 matrix """
        X = input.Exp()
        I = torch.eye(3, dtype=X.dtype, device=X.device)
        I = I.view([1] * (X.dim() - 1) + [3, 3])
        return X.unsqueeze(-2).Act(I).transpose(-1,-2)

    def rotation(self, input):
        '''Rotation of ``Exp(input)``.'''
        return input.Exp().rotation()

    def Jr(self, x):
        """
        Right jacobian of so(3)

        Evaluates ``I - c1 * K + c2 * K @ K`` with ``K = vec2skew(x)``,
        ``theta = ||x||``, ``c1 = (1 - cos theta) / theta**2`` and
        ``c2 = (theta - sin theta) / theta**3``. When ``theta`` is at most
        machine epsilon the coefficients are replaced by their limits 1/2
        and 1/6, so ``x == 0`` gives exactly ``I``.
        """
        K = vec2skew(x)
        theta = torch.linalg.norm(x, dim=-1, keepdim=True).unsqueeze(-1)
        I = torch.eye(3, device=x.device, dtype=x.dtype).expand(x.lshape+(3, 3))
        small = theta <= torch.finfo(theta.dtype).eps
        # Evaluate the closed form on a safe angle: forming 0/0 and discarding
        # it with torch.where afterwards still propagates NaN into gradients.
        safe = torch.where(small, torch.ones_like(theta), theta)
        c1 = torch.where(small, torch.full_like(theta, 1 / 2), (1 - safe.cos()) / safe**2)
        c2 = torch.where(small, torch.full_like(theta, 1 / 6), (safe - safe.sin()) / safe**3)
        return I - c1 * K + c2 * (K @ K)


class SE3Type(LieType):
    '''LieType of the group SE(3).

    LieTensors of this type have a last dimension of 7, laid out as
    ``[translation (0:3), rotation (3:7)]`` (see :meth:`translation` and
    :meth:`rotation`); the identity element is ``[0, 0, 0, 0, 0, 0, 1]``.
    The associated Lie algebra type is ``se3_type``.
    '''
    def __init__(self):
        super().__init__(7, 7, 6)

    def Log(self, X):
        '''Logarithm map: SE3 data ``X`` -> ``se3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        x = SE3_Log.apply(X)
        return LieTensor(x, ltype=se3_type)

    def Act(self, X, p):
        '''Act with transforms ``X`` on points ``p`` of shape ``(*, 3)`` or ``(*, 4)``.

        Raises:
            AssertionError: if ``p`` is not a tensor or ``p.shape[-1]`` is
                neither 3 nor 4.
        '''
        assert not self.on_manifold and isinstance(p, torch.Tensor)
        assert p.shape[-1]==3 or p.shape[-1]==4, "Invalid Tensor Dimension"
        X = X.tensor() if hasattr(X, 'ltype') else X
        inputs, out_shape = broadcast_inputs(X, p)
        if p.shape[-1]==3:
            out = SE3_Act.apply(*inputs)
        else:
            out = SE3_Act4.apply(*inputs)
        # view(..., -1) cannot infer a size for an empty tensor; transformed
        # points keep the size of the input points (3 or 4), not 7.
        dim = -1 if out.nelement() != 0 else p.shape[-1]
        return out.view(out_shape + (dim,))

    def Mul(self, X, Y):
        '''Compose with another Lie-group LieTensor, or act on a point tensor.

        Raises:
            NotImplementedError: if ``Y`` is neither.
        '''
        # Transform on transform
        X = X.tensor() if hasattr(X, 'ltype') else X
        if not self.on_manifold and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            Y = Y.tensor() if hasattr(Y, 'ltype') else Y
            inputs, out_shape = broadcast_inputs(X, Y)
            out = SE3_Mul.apply(*inputs)
            # Composition keeps the group's own trailing size.
            dim = -1 if out.nelement() != 0 else X.shape[-1]
            out = out.view(out_shape + (dim,))
            return LieTensor(out, ltype=SE3_type)
        # Transform on points
        if not self.on_manifold and isinstance(Y, torch.Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=SE3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Group inverse, returned as an ``SE3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        out = SE3_Inv.apply(X)
        return LieTensor(out, ltype=SE3_type)

    def rotation(self, input):
        '''Rotation part ``[..., 3:7]`` as an ``SO3_type`` LieTensor.'''
        return LieTensor(input.tensor()[..., 3:7], ltype=SO3_type)

    def translation(self, input):
        '''Translation part ``[..., 0:3]`` of ``input.tensor()``.'''
        return input.tensor()[..., 0:3]

    def Adj(self, X, a):
        '''Adjoint ``b`` with ``X * Exp(a) = Exp(b) * X``, as ``se3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = SE3_AdjXa.apply(*inputs)
        # Empty result: trailing size is the algebra (manifold) size, which
        # is what the se3_type LieTensor below requires.
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=se3_type)

    def AdjT(self, X, a):
        '''Adjoint transpose ``b`` with ``Exp(a) * X = X * Exp(b)``, as ``se3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = SE3_AdjTXa.apply(*inputs)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=se3_type)

    def Jinvp(self, X, a):
        '''Compute ``se3_Jl_inv(Log(X)) @ a`` (``se3_type``).'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        (X, a), out_shape = broadcast_inputs(X, a)
        out = (se3_Jl_inv(SE3_Log.apply(X)) @ a.unsqueeze(-1)).squeeze(-1)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=se3_type)

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Identity transform(s) with lshape ``size``; ``kwargs`` go to ``torch.tensor``.

        ``size`` may be given as ints and/or iterables of ints (flattened by
        ``to_tuple``, as ``randn`` does).
        '''
        size = cls.to_tuple(size)
        data = torch.tensor([0., 0., 0., 0., 0., 0., 1.], **kwargs)
        return LieTensor(data.repeat(size+(1,)), ltype=SE3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random transforms: ``Exp`` of detached ``se3_type.randn`` samples.'''
        data = se3_type.Exp(se3_type.randn(*size, sigma=sigma, **kwargs)).detach()
        return LieTensor(data, ltype=SE3_type).requires_grad_(requires_grad)

    @classmethod
    def add_(cls, input, other):
        '''In place: ``input <- Exp(other[..., :6]) * input``.'''
        return input.copy_(LieTensor(other[..., :6], ltype=se3_type).Exp() * input)


class se3Type(LieType):
    '''LieType of the Lie algebra se(3): LieTensors with a last dimension of 6.

    ``randn`` builds samples laid out as ``[translation (3), rotation (3)]``.
    The associated group type is ``SE3_type``.
    '''
    def __init__(self):
        super().__init__(6, 7, 6)

    def Exp(self, x):
        '''Exponential map: se3 data ``x`` -> ``SE3_type`` LieTensor.'''
        x = x.tensor() if hasattr(x, 'ltype') else x
        X = se3_Exp.apply(x)
        return LieTensor(X, ltype=SE3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul`` with a scalar or tensor, kept as ``se3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=se3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def rotation(self, input):
        '''Rotation of ``Exp(input)``.'''
        return input.Exp().rotation()

    def translation(self, input):
        '''Translation of ``Exp(input)``.'''
        return input.Exp().translation()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''``Log`` of the SE3 identity with the same arguments.'''
        return SE3_type.Log(SE3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random se3 elements ``[translation (3), rotation (3)]``.

        Args:
            sigma: a number shared by all components, or an iterable of
                1 value (shared), 2 values ``(translation, rotation)`` or
                4 values ``(tx, ty, tz, rotation)``.
            requires_grad (bool): ``requires_grad`` of the result.
            **kwargs: forwarded to ``torch.randn`` (e.g. ``dtype``,
                ``device``, ``generator``).
        '''
        #  convert different types of inputs to SE3 sigma
        if not isinstance(sigma, collections.abc.Iterable):
            sigma = _quadruple(sigma)
        elif len(sigma)==1:
            # A one-element iterable is shared by all components, as the
            # error message below promises.
            sigma = _quadruple(sigma[0])
        elif len(sigma)==2:
            rotation_sigma = _single(sigma[-1])
            translation_sigma = _triple(sigma[0])
            sigma = translation_sigma + rotation_sigma
        else:
            assert len(sigma)==4, 'Only accepts a tuple of sigma in size 1, 2, or 4.'
        size = self.to_tuple(size)
        rotation = so3_type.randn(*size, sigma=sigma[-1], **kwargs).detach().tensor()
        noise = torch.randn(*(size + torch.Size([3])), **kwargs)
        # Match the sample's dtype/device instead of forwarding randn-only
        # kwargs (e.g. generator) to torch.tensor.
        translation_sigma = torch.tensor([sigma[0], sigma[1], sigma[2]],
                                         dtype=noise.dtype, device=noise.device)
        translation = translation_sigma * noise
        data = torch.cat([translation, rotation], dim=-1)
        return LieTensor(data, ltype=se3_type).requires_grad_(requires_grad)


class Sim3Type(LieType):
    '''LieType of the similarity group Sim(3).

    LieTensors of this type have a last dimension of 8, laid out as
    ``[translation (0:3), rotation (3:7), scale (7:8)]``; the identity
    element is ``[0, 0, 0, 0, 0, 0, 1, 1]``. The associated Lie algebra type
    is ``sim3_type``.
    '''
    def __init__(self):
        super().__init__(8, 8, 7)

    def Log(self, X):
        '''Logarithm map: Sim3 data ``X`` -> ``sim3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        x = Sim3_Log.apply(X)
        return LieTensor(x, ltype=sim3_type)

    def Act(self, X, p):
        '''Act with transforms ``X`` on points ``p`` of shape ``(*, 3)`` or ``(*, 4)``.

        Raises:
            AssertionError: if ``p`` is not a tensor or ``p.shape[-1]`` is
                neither 3 nor 4.
        '''
        assert not self.on_manifold and isinstance(p, torch.Tensor)
        assert p.shape[-1]==3 or p.shape[-1]==4, "Invalid Tensor Dimension"
        X = X.tensor() if hasattr(X, 'ltype') else X
        inputs, out_shape = broadcast_inputs(X, p)
        if p.shape[-1]==3:
            out = Sim3_Act.apply(*inputs)
        else:
            out = Sim3_Act4.apply(*inputs)
        # view(..., -1) cannot infer a size for an empty tensor; transformed
        # points keep the size of the input points (3 or 4), not 8.
        dim = -1 if out.nelement() != 0 else p.shape[-1]
        return out.view(out_shape + (dim,))

    def Mul(self, X, Y):
        '''Compose with another Lie-group LieTensor, or act on a point tensor.

        Raises:
            NotImplementedError: if ``Y`` is neither.
        '''
        # Transform on transform
        X = X.tensor() if hasattr(X, 'ltype') else X
        if not self.on_manifold and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            Y = Y.tensor() if hasattr(Y, 'ltype') else Y
            inputs, out_shape = broadcast_inputs(X, Y)
            out = Sim3_Mul.apply(*inputs)
            # Composition keeps the group's own trailing size.
            dim = -1 if out.nelement() != 0 else X.shape[-1]
            out = out.view(out_shape + (dim,))
            return LieTensor(out, ltype=Sim3_type)
        # Transform on points
        if not self.on_manifold and isinstance(Y, torch.Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=Sim3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Group inverse, returned as a ``Sim3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        out = Sim3_Inv.apply(X)
        return LieTensor(out, ltype=Sim3_type)

    def Adj(self, X, a):
        '''Adjoint ``b`` with ``X * Exp(a) = Exp(b) * X``, as ``sim3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = Sim3_AdjXa.apply(*inputs)
        # Empty result: trailing size is the algebra (manifold) size, which
        # is what the sim3_type LieTensor below requires.
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=sim3_type)

    def AdjT(self, X, a):
        '''Adjoint transpose ``b`` with ``Exp(a) * X = X * Exp(b)``, as ``sim3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = Sim3_AdjTXa.apply(*inputs)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=sim3_type)

    def Jinvp(self, X, a):
        '''Compute ``sim3_Jl_inv(Log(X)) @ a`` (``sim3_type``).'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        (X, a), out_shape = broadcast_inputs(X, a)
        out = (sim3_Jl_inv(Sim3_Log.apply(X)) @ a.unsqueeze(-1)).squeeze(-1)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=sim3_type)

    def rotation(self, input):
        '''Rotation part ``[..., 3:7]`` as an ``SO3_type`` LieTensor.'''
        return LieTensor(input.tensor()[..., 3:7], ltype=SO3_type)

    def translation(self, input):
        '''Translation part ``[..., 0:3]`` of ``input.tensor()``.'''
        return input.tensor()[..., 0:3]

    def scale(self, input):
        '''Scale part ``[..., 7:8]`` of ``input.tensor()``.'''
        return input.tensor()[..., 7:8]

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Identity element(s) with lshape ``size``; ``kwargs`` go to ``torch.tensor``.

        ``size`` may be given as ints and/or iterables of ints (flattened by
        ``to_tuple``, as ``randn`` does).
        '''
        size = cls.to_tuple(size)
        data = torch.tensor([0., 0., 0., 0., 0., 0., 1., 1.], **kwargs)
        return LieTensor(data.repeat(size+(1,)), ltype=Sim3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random elements: ``Exp`` of detached ``sim3_type.randn`` samples.'''
        data = sim3_type.Exp(sim3_type.randn(*size, sigma=sigma, **kwargs)).detach()
        return LieTensor(data, ltype=Sim3_type).requires_grad_(requires_grad)

    @classmethod
    def add_(cls, input, other):
        '''In place: ``input <- Exp(other[..., :7]) * input``.'''
        return input.copy_(LieTensor(other[..., :7], ltype=sim3_type).Exp() * input)


class sim3Type(LieType):
    '''LieType of the Lie algebra sim(3): LieTensors with a last dimension of 7.

    ``randn`` builds samples laid out as
    ``[translation (3), rotation (3), scale (1)]``. The associated group type
    is ``Sim3_type``.
    '''
    def __init__(self):
        super().__init__(7, 8, 7)

    def Exp(self, x):
        '''Exponential map: sim3 data ``x`` -> ``Sim3_type`` LieTensor.'''
        x = x.tensor() if hasattr(x, 'ltype') else x
        X = sim3_Exp.apply(x)
        return LieTensor(X, ltype=Sim3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul`` with a scalar or tensor, kept as ``sim3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=sim3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def rotation(self, input):
        '''Rotation of ``Exp(input)``.'''
        return input.Exp().rotation()

    def translation(self, input):
        '''Translation of ``Exp(input)``.'''
        return input.Exp().translation()

    def scale(self, input):
        '''Scale of ``Exp(input)``.'''
        return input.Exp().scale()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''``Log`` of the Sim3 identity with the same arguments.'''
        return Sim3_type.Log(Sim3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random sim3 elements ``[translation (3), rotation (3), scale (1)]``.

        Args:
            sigma: a number shared by all components, or an iterable of
                1 value (shared), 3 values ``(translation, rotation, scale)``
                or 5 values ``(tx, ty, tz, rotation, scale)``.
            requires_grad (bool): ``requires_grad`` of the result.
            **kwargs: forwarded to ``torch.randn`` (e.g. ``dtype``,
                ``device``, ``generator``).
        '''
        if not isinstance(sigma, collections.abc.Iterable):
            sigma = _ntuple(5, "_penta")(sigma)
        elif len(sigma)==1:
            # A one-element iterable is shared by all components, as the
            # error message below promises.
            sigma = _ntuple(5, "_penta")(sigma[0])
        elif len(sigma)==3:
            rotation_sigma = _single(sigma[-2])
            scale_sigma = _single(sigma[-1])
            translation_sigma = _triple(sigma[0])
            sigma = translation_sigma+rotation_sigma+scale_sigma
        else:
            assert len(sigma)==5, 'Only accepts a tuple of sigma in size 1, 3, or 5.'
        size = self.to_tuple(size)
        rotation = so3_type.randn(*size, sigma=sigma[-2], **kwargs).detach().tensor()
        scale = sigma[-1] * torch.randn(*(size + torch.Size([1])), **kwargs)
        noise = torch.randn(*(size + torch.Size([3])), **kwargs)
        # Match the sample's dtype/device instead of forwarding randn-only
        # kwargs (e.g. generator) to torch.tensor.
        translation_sigma = torch.tensor([sigma[0], sigma[1], sigma[2]],
                                         dtype=noise.dtype, device=noise.device)
        translation = translation_sigma * noise
        data = torch.cat([translation, rotation, scale], dim=-1)
        return LieTensor(data, ltype=sim3_type).requires_grad_(requires_grad)


class RxSO3Type(LieType):
    '''LieType of the group RxSO(3) (rotation plus scale).

    LieTensors of this type have a last dimension of 5, laid out as
    ``[rotation (0:4), scale (4:5)]``; the identity element is
    ``[0, 0, 0, 1, 1]``. The associated Lie algebra type is ``rxso3_type``.
    '''
    def __init__(self):
        super().__init__(5, 5, 4)

    def Log(self, X):
        '''Logarithm map: RxSO3 data ``X`` -> ``rxso3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        x = RxSO3_Log.apply(X)
        return LieTensor(x, ltype=rxso3_type)

    def Act(self, X, p):
        '''Act with elements ``X`` on points ``p`` of shape ``(*, 3)`` or ``(*, 4)``.

        Raises:
            AssertionError: if ``p`` is not a tensor or ``p.shape[-1]`` is
                neither 3 nor 4.
        '''
        assert not self.on_manifold and isinstance(p, torch.Tensor)
        assert p.shape[-1]==3 or p.shape[-1]==4, "Invalid Tensor Dimension"
        X = X.tensor() if hasattr(X, 'ltype') else X
        inputs, out_shape = broadcast_inputs(X, p)
        if p.shape[-1]==3:
            out = RxSO3_Act.apply(*inputs)
        else:
            out = RxSO3_Act4.apply(*inputs)
        # view(..., -1) cannot infer a size for an empty tensor; transformed
        # points keep the size of the input points (3 or 4), not 5.
        dim = -1 if out.nelement() != 0 else p.shape[-1]
        return out.view(out_shape + (dim,))

    def Mul(self, X, Y):
        '''Compose with another Lie-group LieTensor, or act on a point tensor.

        Raises:
            NotImplementedError: if ``Y`` is neither.
        '''
        # Transform on transform
        X = X.tensor() if hasattr(X, 'ltype') else X
        if not self.on_manifold and isinstance(Y, LieTensor) and not Y.ltype.on_manifold:
            Y = Y.tensor() if hasattr(Y, 'ltype') else Y
            inputs, out_shape = broadcast_inputs(X, Y)
            out = RxSO3_Mul.apply(*inputs)
            # Composition keeps the group's own trailing size.
            dim = -1 if out.nelement() != 0 else X.shape[-1]
            out = out.view(out_shape + (dim,))
            return LieTensor(out, ltype=RxSO3_type)
        # Transform on points
        if not self.on_manifold and isinstance(Y, torch.Tensor):
            return self.Act(X, Y)
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=RxSO3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def Inv(self, X):
        '''Group inverse, returned as an ``RxSO3_type`` LieTensor.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        out = RxSO3_Inv.apply(X)
        return LieTensor(out, ltype=RxSO3_type)

    def Adj(self, X, a):
        '''Adjoint ``b`` with ``X * Exp(a) = Exp(b) * X``, as ``rxso3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = RxSO3_AdjXa.apply(*inputs)
        # Empty result: trailing size is the algebra (manifold) size, which
        # is what the rxso3_type LieTensor below requires.
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=rxso3_type)

    def AdjT(self, X, a):
        '''Adjoint transpose ``b`` with ``Exp(a) * X = X * Exp(b)``, as ``rxso3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        inputs, out_shape = broadcast_inputs(X, a)
        out = RxSO3_AdjTXa.apply(*inputs)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=rxso3_type)

    def Jinvp(self, X, a):
        '''Compute ``rxso3_Jl_inv(Log(X)) @ a`` (``rxso3_type``).'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        a = a.tensor() if hasattr(a, 'ltype') else a
        (X, a), out_shape = broadcast_inputs(X, a)
        out = (rxso3_Jl_inv(RxSO3_Log.apply(X)) @ a.unsqueeze(-1)).squeeze(-1)
        dim = -1 if out.nelement() != 0 else self.manifold[0]
        out = out.view(out_shape + (dim,))
        return LieTensor(out, ltype=rxso3_type)

    def rotation(self, input):
        '''Rotation part ``[..., 0:4]`` as an ``SO3_type`` LieTensor.'''
        return LieTensor(input.tensor()[..., 0:4], ltype=SO3_type)

    def scale(self, input):
        '''Scale part ``[..., 4:5]`` of ``input.tensor()``.'''
        return input.tensor()[..., 4:5]

    @classmethod
    def identity(cls, *size, **kwargs):
        '''Identity element(s) with lshape ``size``; ``kwargs`` go to ``torch.tensor``.

        ``size`` may be given as ints and/or iterables of ints (flattened by
        ``to_tuple``, as ``randn`` does).
        '''
        size = cls.to_tuple(size)
        data = torch.tensor([0., 0., 0., 1., 1.], **kwargs)
        return LieTensor(data.repeat(size+(1,)), ltype=RxSO3_type)

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random elements: ``Exp`` of detached ``rxso3_type.randn`` samples.'''
        data = rxso3_type.Exp(rxso3_type.randn(*size, sigma=sigma, **kwargs)).detach()
        return LieTensor(data, ltype=RxSO3_type).requires_grad_(requires_grad)

    @classmethod
    def add_(cls, input, other):
        '''In place: ``input <- Exp(other[..., :4]) * input``.'''
        return input.copy_(LieTensor(other[..., :4], ltype=rxso3_type).Exp() * input)


class rxso3Type(LieType):
    '''LieType of the Lie algebra rxso(3): LieTensors with a last dimension of 4.

    ``randn`` builds samples laid out as ``[rotation (3), scale (1)]``. The
    associated group type is ``RxSO3_type``.
    '''
    def __init__(self):
        super().__init__(4, 5, 4)

    def Exp(self, x):
        '''Exponential map: rxso3 data ``x`` -> ``RxSO3_type`` LieTensor.'''
        x = x.tensor() if hasattr(x, 'ltype') else x
        X = rxso3_Exp.apply(x)
        return LieTensor(X, ltype=RxSO3_type)

    def Mul(self, X, Y):
        '''Element-wise ``torch.mul`` with a scalar or tensor, kept as ``rxso3_type``.'''
        X = X.tensor() if hasattr(X, 'ltype') else X
        # (scalar or tensor) * manifold
        if self.on_manifold:
            return LieTensor(torch.mul(X, Y), ltype=rxso3_type)
        raise NotImplementedError('Invalid __mul__ operation')

    def rotation(self, input):
        '''Rotation of ``Exp(input)``.'''
        return input.Exp().rotation()

    def scale(self, input):
        '''Scale of ``Exp(input)``.'''
        return input.Exp().scale()

    @classmethod
    def identity(cls, *size, **kwargs):
        '''``Log`` of the RxSO3 identity with the same arguments.'''
        return RxSO3_type.Log(RxSO3_type.identity(*size, **kwargs))

    def randn(self, *size, sigma=1.0, requires_grad=False, **kwargs):
        '''Random rxso3 elements ``[rotation (3), scale (1)]``.

        Args:
            sigma: a number shared by both parts, or an iterable of 1 value
                (shared) or 2 values ``(rotation, scale)``.
            requires_grad (bool): ``requires_grad`` of the result.
            **kwargs: forwarded to ``torch.randn``.
        '''
        if not isinstance(sigma, collections.abc.Iterable):
            sigma = _pair(sigma)
        elif len(sigma)==1:
            # A one-element iterable is shared by rotation and scale, as the
            # error message below promises.
            sigma = _pair(sigma[0])
        else:
            assert len(sigma)==2, 'Only accepts a tuple of sigma in size 1 or 2.'
        size = self.to_tuple(size)
        rotation = so3_type.randn(*size, sigma=sigma[0], **kwargs).tensor()
        scale = sigma[1] * torch.randn(*(size + torch.Size([1])), **kwargs)
        data = torch.cat([rotation, scale], dim=-1)
        return LieTensor(data, ltype=rxso3_type).requires_grad_(requires_grad)


# Module-level singleton instances of every LieType. LieTensor objects are
# tagged with one of these via their ``ltype`` argument; each group type is
# created immediately before its Lie algebra counterpart.
SO3_type = SO3Type()
so3_type = so3Type()
SE3_type = SE3Type()
se3_type = se3Type()
Sim3_type = Sim3Type()
sim3_type = sim3Type()
RxSO3_type = RxSO3Type()
rxso3_type = rxso3Type()


[docs]class LieTensor(torch.Tensor): r""" A sub-class of :obj:`torch.Tensor` to represent Lie Algebra and Lie Group. Args: data (:obj:`Tensor`, or :obj:`list`, or ':obj:`int`...'): A :obj:`Tensor` object, or constructing a :obj:`Tensor` object from :obj:`list`, which defines tensor data, or from ':obj:`int`...', which defines tensor shape. The shape of :obj:`Tensor` object should be compatible with Lie Type :obj:`ltype`, otherwise error will be raised. ltype (:obj:`ltype`): Lie Type, either **Lie Group** or **Lie Algebra** is listed below: Returns: LieTensor corresponding to Lie Type :obj:`ltype`. .. list-table:: List of :obj:`ltype` for **Lie Group** :widths: 25 25 30 30 :header-rows: 1 * - Representation - :obj:`ltype` - :obj:`shape` - Alias Class * - Rotation - :obj:`SO3_type` - :obj:`(*, 4)` - :meth:`SO3` * - Translation + Rotation - :obj:`SE3_type` - :obj:`(*, 7)` - :meth:`SE3` * - Translation + Rotation + Scale - :obj:`Sim3_type` - :obj:`(*, 8)` - :meth:`Sim3` * - Rotation + Scale - :obj:`RxSO3_type` - :obj:`(*, 5)` - :meth:`RxSO3` .. list-table:: List of :obj:`ltype` for **Lie Algebra** :widths: 25 25 30 30 :header-rows: 1 * - Representation - :obj:`ltype` - :obj:`shape` - Alias Class * - Rotation - :obj:`so3_type` - :obj:`(*, 3)` - :meth:`so3` * - Translation + Rotation - :obj:`se3_type` - :obj:`(*, 6)` - :meth:`se3` * - Translation + Rotation + Scale - :obj:`sim3_type` - :obj:`(*, 7)` - :meth:`sim3` * - Rotation + Scale - :obj:`rxso3_type` - :obj:`(*, 4)` - :meth:`rxso3` Note: Two attributes :obj:`shape` and :obj:`lshape` are available for LieTensor. The only differece is the :obj:`lshape` hides the last dimension of :obj:`shape`, since :obj:`lshape` takes the data in the last dimension as a single :obj:`ltype` item. See LieTensor method :meth:`lview` for more details. 
Examples: >>> import torch >>> import pypose as pp >>> data = torch.randn(3, 3, requires_grad=True, device='cuda:0') >>> pp.LieTensor(data, ltype=pp.so3_type) so3Type LieTensor: tensor([[ 0.9520, 0.4517, 0.5834], [-0.8106, 0.8197, 0.7077], [-0.5743, 0.8182, -1.2104]], device='cuda:0', grad_fn=<AliasBackward0>) Alias class for specific LieTensor is recommended: >>> pp.so3(data) so3Type LieTensor: tensor([[ 0.9520, 0.4517, 0.5834], [-0.8106, 0.8197, 0.7077], [-0.5743, 0.8182, -1.2104]], device='cuda:0', grad_fn=<AliasBackward0>) See more alias classes at `Table 1 for Lie Group <#id1>`_ and `Table 2 for Lie Algebra <#id2>`_. Other constructors: - From list. >>> pp.so3([0, 0, 0]) so3Type LieTensor: tensor([0., 0., 0.]) - From ints. >>> pp.so3(2, 3) so3Type LieTensor: tensor([[0., 0., 0.], [0., 0., 0.]]) Note: Alias class for LieTensor is recommended. For example, the following usage is equivalent: - :obj:`pp.LieTensor(tensor, ltype=pp.so3_type)` - :obj:`pp.so3(tensor)` (This is preferred). Note: All attributes from Tensor are available for LieTensor, e.g., :obj:`dtype`, :obj:`device`, and :obj:`requires_grad`. See more details at `tensor attributes <https://pytorch.org/docs/stable/tensor_attributes.html>`_. Example: >>> data = torch.randn(1, 3, dtype=torch.float64, device="cuda", requires_grad=True) >>> pp.so3(data) # All Tensor attributes are available for LieTensor so3Type LieTensor: tensor([[-1.5948, 0.3113, -0.9807]], device='cuda:0', dtype=torch.float64, grad_fn=<AliasBackward0>) Note: In most of the cases, Lie Group is expected to be used, therefore we only provide `converting functions <https://pypose.org/docs/main/convert/>`_ between Lie Groups and other data structures, e.g., transformation matrix, Euler angle, etc. The users can convert data between Lie Group and Lie algebra with :obj:`Exp` and :obj:`Log`. 
""" def __init__(self, *data, ltype:LieType): assert self.shape[-1:] == ltype.dimension, 'The last dimension of a LieTensor has to be ' \ 'corresponding to their LieType. More details go to {}. If this error happens in an ' \ 'optimization process, where LieType is not a necessary structure, we suggest to ' \ 'call .tensor() to convert a LieTensor to Tensor before passing it to an optimizer. ' \ 'If this still happens, create an issue on GitHub please.'.format( 'https://pypose.org/docs/main/generated/pypose.LieTensor') self.ltype = ltype @staticmethod def __new__(cls, *data, ltype): tensor = data[0] if isinstance(data[0], torch.Tensor) else torch.Tensor(*data) return torch.Tensor.as_subclass(tensor, LieTensor) def __repr__(self): if hasattr(self, 'ltype'): return self.ltype.__class__.__name__ + \ ' %s:\n'%(self.__class__.__name__) + super().__repr__() else: return super().__repr__()
[docs] def new_empty(self, shape): return torch.Tensor.as_subclass(torch.empty(shape), LieTensor)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    r'''
    Dispatch ``func`` to the default Tensor implementation and, for functions
    named in ``HANDLED_FUNCTIONS``, re-wrap plain Tensor outputs as LieTensors
    carrying the LieType of the input.

    Args:
        func: the torch function or Tensor method being dispatched.
        types: the overloaded argument types collected by torch.
        args (tuple): positional arguments of ``func``.
        kwargs (dict, optional): keyword arguments of ``func``.

    Returns:
        The result of ``func``. Tensor leaves are converted to LieTensor only
        when ``func.__name__`` is in ``HANDLED_FUNCTIONS`` and a LieTensor with
        an ``ltype`` is found among ``args`` or ``kwargs``; otherwise the raw
        result is returned.
    '''
    if kwargs is None:  # avoid a shared mutable default argument
        kwargs = {}
    # Every overloaded type is presented as torch.Tensor so the default handler
    # runs ``func`` and returns plain tensors. This is exactly what the earlier
    # spelling ``t is LieTensor or Parameter`` evaluated to: it parses as
    # ``(t is LieTensor) or Parameter``, which is always truthy.
    ltypes = tuple(torch.Tensor for _ in types)
    data = torch.Tensor.__torch_function__(func, ltypes, args, kwargs)
    if data is None or getattr(func, '__name__', None) not in HANDLED_FUNCTIONS:
        return data
    # Use the LieType of the first LieTensor (positional before keyword) that
    # has one; e.g. ``new_empty`` results carry no ``ltype``.
    leaves, _ = tree_flatten((args, kwargs))
    ltype = next((leaf.ltype for leaf in leaves
                  if isinstance(leaf, LieTensor) and hasattr(leaf, 'ltype')), None)
    if ltype is None:
        return data

    def warp(t):
        # Wrap plain Tensor outputs; leave non-tensors and ``cls`` instances as is.
        if isinstance(t, torch.Tensor) and not isinstance(t, cls):
            lt = torch.Tensor.as_subclass(t, LieTensor)
            lt.ltype = ltype
            if lt.shape[-1:] != lt.ltype.dimension:
                link = 'https://pypose.org/docs/main/generated/pypose.LieTensor'
                warnings.warn('Tensor Shape Invalid by calling {}, ' \
                    'go to {}'.format(func, link))
            return lt
        return t

    return tree_map(warp, data)

@property
def lshape(self) -> torch.Size:
    r'''
    LieTensor Shape (shape of torch.Tensor by ignoring the last dimension)

    Returns:
        torch.Size

    Note:
        - The only difference from :obj:`shape` is the last dimension is hidden,
          since :obj:`lshape` takes the last dimension as a single :obj:`ltype` item.
        - The last dimension can also be accessed via :obj:`LieTensor.ltype.dimension`.

    Examples:
        >>> x = pp.randn_SE3(2)
        >>> x.lshape
        torch.Size([2])
        >>> x.shape
        torch.Size([2, 7])
        >>> x.ltype.dimension
        torch.Size([7])
    '''
    return self.shape[:-1]
def lview(self, *shape):
    r'''
    Returns a new LieTensor with the same data as the self tensor but of a
    different :obj:`lshape`.

    Args:
        shape (torch.Size or int...): the desired size, given either as
            several ints or as a single tuple/list/``torch.Size``.

    Returns:
        A new LieTensor sharing the same data as the self tensor but of a
        different shape.

    Note:
        The only difference from :meth:`view` is the last dimension is hidden.
        See `Tensor.view <https://tinyurl.com/mrds8nmd>`_ for its usage.

    Examples:
        >>> x = pp.randn_so3(2,2)
        >>> x.shape
        torch.Size([2, 2, 3])
        >>> x.lview(-1).lshape
        torch.Size([4])
    '''
    # Unwrap a single sequence argument, as Tensor.view does; otherwise it would
    # end up nested inside the shape tuple and view() would reject it.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    return self.view(*shape, *self.ltype.dimension)
def Exp(self):
    r'''
    Delegate to ``self.ltype.Exp``.

    See :meth:`pypose.Exp`
    '''
    ltype = self.ltype
    return ltype.Exp(self)
def Log(self):
    r'''
    Delegate to ``self.ltype.Log``.

    See :meth:`pypose.Log`
    '''
    ltype = self.ltype
    return ltype.Log(self)
def Inv(self):
    r'''
    Delegate to ``self.ltype.Inv``.

    See :meth:`pypose.Inv`
    '''
    ltype = self.ltype
    return ltype.Inv(self)
def Act(self, p):
    r'''
    Delegate to ``self.ltype.Act`` with ``p`` as the second argument.

    See :meth:`pypose.Act`
    '''
    ltype = self.ltype
    return ltype.Act(self, p)
def add(self, other, alpha=1):
    r'''
    Out-of-place add: clone ``self`` and apply :meth:`add_` with
    ``alpha * other`` to the copy.

    See :meth:`pypose.add`
    '''
    result = self.clone()
    scaled = alpha * other
    return result.add_(other=scaled)
def add_(self, other, alpha=1):
    r'''
    In-place add, delegated to ``self.ltype.add_`` with ``alpha * other``.

    See :meth:`pypose.add_`
    '''
    ltype = self.ltype
    scaled = alpha * other
    return ltype.add_(self, other=scaled)
def __add__(self, other):
    r'''``self + other``: forwards to :meth:`add` with the default ``alpha``.'''
    adder = self.add
    return adder(other=other)

def __mul__(self, other):
    r'''
    See :meth:`pypose.mul`
    '''
    ltype = self.ltype
    return ltype.Mul(self, other)

def __matmul__(self, other):
    r'''
    See :meth:`pypose.matmul`
    '''
    if not isinstance(other, LieTensor):
        # Same with: self.ltype.matrix(self) @ other
        return self.Act(other)
    return self.ltype.Mul(self, other)
def Retr(self, a):
    r'''
    Delegate to ``self.ltype.Retr`` with ``a`` as the second argument.

    See :meth:`pypose.Retr`
    '''
    ltype = self.ltype
    return ltype.Retr(self, a)
def Adj(self, a):
    r'''
    Delegate to ``self.ltype.Adj`` with ``a`` as the second argument.

    See :meth:`pypose.Adj`
    '''
    ltype = self.ltype
    return ltype.Adj(self, a)
def AdjT(self, a):
    r'''
    Delegate to ``self.ltype.AdjT`` with ``a`` as the second argument.

    See :meth:`pypose.AdjT`
    '''
    ltype = self.ltype
    return ltype.AdjT(self, a)
def Jinvp(self, p):
    r'''
    Delegate to ``self.ltype.Jinvp`` with ``p`` as the second argument.

    See :meth:`pypose.Jinvp`
    '''
    ltype = self.ltype
    return ltype.Jinvp(self, p)
def Jr(self):
    r'''
    Delegate to ``self.ltype.Jr``.

    See :meth:`pypose.Jr`
    '''
    ltype = self.ltype
    return ltype.Jr(self)
def tensor(self) -> torch.Tensor:
    r'''
    Return the data re-typed as a plain :class:`torch.Tensor` (via
    ``as_subclass``, so the same data is shared).

    See :meth:`pypose.tensor`
    '''
    return self.as_subclass(torch.Tensor)
def matrix(self) -> torch.Tensor:
    r'''
    Delegate to ``self.ltype.matrix``.

    See :meth:`pypose.matrix`
    '''
    ltype = self.ltype
    return ltype.matrix(self)
def translation(self) -> torch.Tensor:
    r'''
    Delegate to ``self.ltype.translation``.

    See :meth:`pypose.translation`
    '''
    ltype = self.ltype
    return ltype.translation(self)
def rotation(self):
    r'''
    Delegate to ``self.ltype.rotation``.

    See :meth:`pypose.rotation`
    '''
    ltype = self.ltype
    return ltype.rotation(self)
def scale(self) -> torch.Tensor:
    r'''
    Delegate to ``self.ltype.scale``.

    See :meth:`pypose.scale`
    '''
    ltype = self.ltype
    return ltype.scale(self)
def euler(self, eps=2e-4) -> torch.Tensor:
    r'''
    See :meth:`pypose.euler`

    Convert the rotation part to ``(roll, pitch, yaw)`` angles, stacked on
    the last dimension. Works element-wise for any batch shape.

    Args:
        eps (float, optional): margin below ``|t2| = 1`` (``t2`` being the
            normalized sine of the pitch) inside which the regular yaw formula
            is used; outside it, a gimbal-lock fallback is used. Default: ``2e-4``.

    Returns:
        torch.Tensor: angles with shape ``(..., 3)``.

    Note:
        Assumes ``self.rotation().tensor()`` is an ``(x, y, z, w)`` quaternion
        (the identity being ``[0, 0, 0, 1]``, as in :meth:`identity_`).
    '''
    data = self.rotation().tensor()
    x, y = data[..., 0], data[..., 1]
    z, w = data[..., 2], data[..., 3]
    xx, yy, zz, ww = x*x, y*y, z*z, w*w
    t0 = 2 * (w * x + y * z)
    t1 = (ww + zz) - (xx + yy)
    t2 = 2 * (w * y - z * x) / (xx + yy + zz + ww)
    t3 = 2 * (w * z + x * y)
    t4 = (ww + xx) - (yy + zz)
    roll = torch.atan2(t0, t1)
    pitch = torch.asin(t2.clamp(-1, 1))
    # Singularity when pitch angle ~ +/-pi/2. The mask must be element-wise:
    # a chained ``a < t2 < b`` calls bool() on a tensor and fails for batches.
    flag = t2.abs() < 1. - eps
    yaw1 = torch.atan2(t3, t4)
    yaw2 = -2 * torch.sign(t2) * torch.atan2(x, w)
    yaw = torch.where(flag, yaw1, yaw2)
    return torch.stack([roll, pitch, yaw], dim=-1)
def identity_(self):
    r'''
    Set this LieTensor to the identity element in place (delegated to
    ``self.ltype.identity_``).

    Return:
        LieTensor: the :obj:`self` LieTensor

    Note:
        Any translation part is set to zeros and the rotation part to the
        identity quaternion.

    Example:
        >>> x = pp.randn_SO3(2)
        >>> x
        SO3Type LieTensor:
        tensor([[-0.0724,  0.1970,  0.0022,  0.9777],
                [ 0.3492,  0.4998, -0.5310,  0.5885]])
        >>> x.identity_()
        SO3Type LieTensor:
        tensor([[0., 0., 0., 1.],
                [0., 0., 0., 1.]])
    '''
    ltype = self.ltype
    return ltype.identity_(self)
def cumops(self, dim, ops):
    r"""
    Delegate to ``self.ltype.cumops`` along ``dim`` with operator ``ops``.

    See :func:`pypose.cumops`
    """
    ltype = self.ltype
    return ltype.cumops(self, dim, ops)
def cummul(self, dim):
    r"""
    Delegate to ``self.ltype.cummul`` along ``dim``.

    See :func:`pypose.cummul`
    """
    ltype = self.ltype
    return ltype.cummul(self, dim)
def cumprod(self, dim, left = True):
    r"""
    Delegate to ``self.ltype.cumprod`` along ``dim``, forwarding ``left``.

    See :func:`pypose.cumprod`
    """
    ltype = self.ltype
    return ltype.cumprod(self, dim, left)
def cumops_(self, dim, ops):
    r"""
    Inplace version of :func:`pypose.cumops`, delegated to ``self.ltype.cumops_``.
    """
    ltype = self.ltype
    return ltype.cumops_(self, dim, ops)
def cummul_(self, dim):
    r"""
    Inplace version of :func:`pypose.cummul`, delegated to ``self.ltype.cummul_``.
    """
    ltype = self.ltype
    return ltype.cummul_(self, dim)
def cumprod_(self, dim, left=True):
    r"""
    Inplace version of :func:`pypose.cumprod`.

    Args:
        dim (int): dimension along which the cumulative product is taken.
        left (bool, optional): same option as in :meth:`cumprod`. It is
            forwarded to ``self.ltype.cumprod_`` only when falsy. Default: ``True``.
    """
    if left:
        # Keep the original two-argument call on the default path so LieTypes
        # whose ``cumprod_`` takes no ``left`` argument behave exactly as before.
        return self.ltype.cumprod_(self, dim)
    return self.ltype.cumprod_(self, dim, left)
class Parameter(LieTensor, nn.Parameter):
    r'''
    A kind of LieTensor that is to be considered a module parameter.

    Parameters are of :meth:`LieTensor` and :meth:`torch.nn.Parameter`,
    that have a very special property when used with Modules: when
    they are assigned as Module attributes they are automatically
    added to the list of its parameters, and will appear, e.g., in
    :meth:`parameters()` iterator.

    Args:
        data (LieTensor): parameter LieTensor.
        requires_grad (bool, optional): if the parameter requires
            gradient. Default: ``True``

    Examples:
        >>> import torch, pypose as pp
        >>> x = pp.Parameter(pp.randn_SO3(2))
        >>> x.Log().sum().backward()
        >>> x.grad
        tensor([[0.8590, 1.4069, 0.6261, 0.0000],
                [1.2869, 1.0748, 0.5385, 0.0000]])
    '''
    def __init__(self, data=None, requires_grad=True):
        # Mirrors __new__'s signature so ``Parameter(x, False)`` is accepted;
        # ``requires_grad`` itself is already applied in __new__.
        # With ``data=None`` __new__ builds an empty tensor with no LieType to copy.
        if data is not None:
            self.ltype = data.ltype

    def __new__(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.tensor([])
        return LieTensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        # Preserve ``requires_grad`` like torch.nn.Parameter.__deepcopy__ does;
        # ``clone`` is a handled function, so the copy keeps its ``ltype``.
        result = type(self)(self.clone(memory_format=torch.preserve_format),
                            self.requires_grad)
        memo[id(self)] = result
        return result

Docs

Access documentation for PyPose

View Docs

Tutorials

Get started with tutorials and examples

View Tutorials

Get Started

Find resources and how to start using pypose

View Resources