In [None]:
%matplotlib inline


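# Cart-Pole Tutorial

This notebook simulates a cart-pole system with PyPose's nonlinear-system module (`pp.module.NLS`), linearizes the dynamics at the final simulated state, and plots the resulting trajectory.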


In [None]:
import torch, pypose as pp
import math, matplotlib.pyplot as plt

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Preparation
Define a class for the cart-pole dynamics, plus a small helper for plotting tensors.




In [None]:
class CartPole(pp.module.NLS):
    """Discrete-time cart-pole dynamics built on PyPose's nonlinear-system module.

    The accelerations have the same form as the classic cart-pole model
    (Barto, Sutton & Anderson, 1983; also used by OpenAI Gym's CartPole).
    They are integrated with a single explicit Euler step of size ``dt``.

    State layout (last dimension): ``[x, x_dot, theta, theta_dot]``.
    Input: horizontal force applied to the cart.

    Args:
        dt: Integration time step.
        length: Pole length parameter used in the dynamics.
            NOTE(review): in the Barto/Gym formulation this quantity is the
            distance from the pivot to the pole's centre of mass (half the pole
            length); confirm which convention callers intend.
        cartmass: Mass of the cart.
        polemass: Mass of the pole.
        gravity: Gravitational acceleration (positive magnitude).
    """

    def __init__(self, dt, length, cartmass, polemass, gravity):
        super().__init__()
        self.tau = dt
        self.length = length
        self.cartmass = cartmass
        self.polemass = polemass
        self.gravity = gravity
        # Parameter combinations reused on every step; computed once here.
        self.polemassLength = self.polemass * self.length
        self.totalMass = self.cartmass + self.polemass

    def state_transition(self, state, input, t=None):
        """Advance ``state`` by one explicit Euler step of length ``self.tau``.

        Args:
            state: Tensor whose last dimension is ``[x, x_dot, theta, theta_dot]``.
                Either a single state of shape ``(4,)`` or a batch ``(..., 4)``.
            input: Force on the cart. Singleton dimensions are squeezed away, so a
                scalar, a ``(1,)`` tensor or a batched ``(..., 1)`` tensor all work.
            t: Unused; accepted to match the ``pp.module.NLS`` interface.

        Returns:
            The next state, with the same shape as ``state``.
        """
        # Split along the last dimension (not dim 0) so batched states of
        # shape (..., 4) decompose correctly. For a (4,) state this is the
        # same as plain unpacking. The position x itself does not enter the
        # dynamics; only its derivative does.
        _, xDot, theta, thetaDot = state.unbind(-1)
        force = input.squeeze()
        costheta = theta.cos()
        sintheta = theta.sin()

        # Shared term of the angular and linear acceleration equations.
        temp = (force + self.polemassLength * thetaDot**2 * sintheta) / self.totalMass

        thetaAcc = (self.gravity * sintheta - costheta * temp) / \
            (self.length * (4.0 / 3.0 - self.polemass * costheta**2 / self.totalMass))

        xAcc = temp - self.polemassLength * thetaAcc * costheta / self.totalMass

        # Re-assemble the time derivative along the same (last) dimension.
        _dstate = torch.stack((xDot, xAcc, thetaDot, thetaAcc), dim=-1)

        # Forward Euler: s_{k+1} = s_k + dt * ds/dt.
        return state + _dstate * self.tau

    def observation(self, state, input, t=None):
        """Full-state observation: the output is the state itself."""
        return state


def subPlot(ax, x, y, xlabel=None, ylabel=None):
    """Draw tensor ``y`` against tensor ``x`` on the matplotlib axes ``ax``.

    Both tensors are detached from autograd and copied to host memory before
    plotting, so GPU tensors and tensors that require grad are accepted.

    Args:
        ax: Target matplotlib axes.
        x: Values for the horizontal axis.
        y: Values for the vertical axis.
        xlabel: Optional horizontal-axis label.
        ylabel: Optional vertical-axis label.
    """
    def as_host_array(tensor):
        # matplotlib needs plain host arrays, not autograd-tracked device tensors.
        return tensor.detach().cpu().numpy()

    ax.plot(as_host_array(x), as_host_array(y))
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

## Parameters for the cart-pole trajectory



In [None]:
dt = 0.01   # Delta t
len = 1.5   # Length of pole
m_cart = 20 # Mass of cart
m_pole = 10 # Mass of pole
g = 9.81    # Accerleration due to gravity
N = 1000    # Number of time steps

Time and input



In [None]:
time  = torch.arange(0, N, device=device) * dt
input = torch.sin(time)

Initial state



In [None]:
state = torch.zeros(N, 4, dtype=float, device=device)
state[0] = torch.tensor([0, 0, math.pi, 0], dtype=float, device=device)

Create dynamics solver object



In [None]:
model = CartPole(dt, len, m_cart, m_pole, g).to(device)

Simulate the trajectory by applying the one-step dynamics `N - 1` times, starting from the initial state.



In [None]:
# Roll the model forward: step k maps state[k] and the k-th input to
# state[k + 1]; the observation returned alongside the new state is discarded.
for k, u_k in enumerate(input[:-1]):
    state[k + 1], _ = model(state[k], u_k)

Jacobian computation: linearize the system at the last time step and print the resulting matrices.



In [None]:
# Linearize the model about the final simulated point, then print each of the
# linearization attributes read back from the model.
model.set_refpoint(state=state[-1, :], input=input[-1], t=time[-1])
# A plain loop, rather than a list comprehension used only for its side effect,
# keeps the cell from also displaying a list of `None`s. The loop variable name
# avoids shadowing the builtin `vars`.
for name in ('A', 'B', 'C', 'D', 'c1', 'c2'):
    print(name, getattr(model, name))

Plot each state component against time.



In [None]:
# One panel per state component, all sharing the time axis. A taller figure
# keeps the four stacked panels legible, and a title lets the figure stand alone.
f, ax = plt.subplots(nrows=4, sharex=True, figsize=(8, 8))
f.suptitle('Cart-pole state trajectory')
x, xdot, theta, thetadot = state.T
subPlot(ax[0], time, x, ylabel='X')
subPlot(ax[1], time, xdot, ylabel='X dot')
subPlot(ax[2], time, theta, ylabel='Theta')
subPlot(ax[3], time, thetadot, ylabel='Theta dot', xlabel='Time')
plt.show()