diff --git a/newton/_src/solvers/featherstone/kernels.py b/newton/_src/solvers/featherstone/kernels_body.py similarity index 99% rename from newton/_src/solvers/featherstone/kernels.py rename to newton/_src/solvers/featherstone/kernels_body.py index 2f5afaa76..b914bbe3e 100644 --- a/newton/_src/solvers/featherstone/kernels.py +++ b/newton/_src/solvers/featherstone/kernels_body.py @@ -358,7 +358,7 @@ def jcalc_tau( ang_axis_count: int, body_f_s: wp.spatial_vector, # outputs - tau: wp.array(dtype=float), + joint_tau: wp.array(dtype=float), ): if type == JointType.BALL: # target_ke = joint_target_ke[dof_start] @@ -370,15 +370,17 @@ def jcalc_tau( # w = joint_qd[dof_start + i] # r = joint_q[coord_start + i] - tau[dof_start + i] = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] - # tau -= w * target_kd - r * target_ke + t = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] + # t -= w * target_kd - r * target_ke + joint_tau[dof_start + i] = t return if type == JointType.FREE or type == JointType.DISTANCE: for i in range(6): S_s = joint_S_s[dof_start + i] - tau[dof_start + i] = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] + t = -wp.dot(S_s, body_f_s) + joint_f[dof_start + i] + joint_tau[dof_start + i] = t return @@ -406,7 +408,7 @@ def jcalc_tau( # total torque / force on the joint t = -wp.dot(S_s, body_f_s) + drive_f + joint_f[j] - tau[j] = t + joint_tau[j] = t return @@ -826,7 +828,7 @@ def eval_rigid_tau( body_f_ext: wp.array(dtype=wp.spatial_vector), # outputs body_ft_s: wp.array(dtype=wp.spatial_vector), - tau: wp.array(dtype=float), + joint_tau: wp.array(dtype=float), ): # one thread per-articulation index = wp.tid() @@ -874,7 +876,7 @@ def eval_rigid_tau( lin_axis_count, ang_axis_count, f_s, - tau, + joint_tau, ) # update parent forces, todo: check that this is valid for the backwards pass diff --git a/newton/_src/solvers/featherstone/solver_featherstone.py b/newton/_src/solvers/featherstone/solver_featherstone.py index d7470810e..0764b93b9 100644 --- a/newton/_src/solvers/featherstone/solver_featherstone.py +++ b/newton/_src/solvers/featherstone/solver_featherstone.py @@ -15,12 +15,13 @@ import warp as wp -from ...core.types import override +from ...core.types import Devicelike, override from ...sim import Contacts, Control, Model, State, eval_fk from ..semi_implicit.kernels_contact import ( eval_body_contact, eval_particle_body_contact_forces, eval_particle_contact_forces, + eval_triangle_contact_forces, ) from ..semi_implicit.kernels_muscle import ( eval_muscle_forces, @@ -32,7 +33,7 @@ eval_triangle_forces, ) from ..solver import SolverBase -from .kernels import ( +from .kernels_body import ( compute_com_transforms, compute_spatial_inertia, create_inertia_matrix_cholesky_kernel, @@ -49,6 +50,237 @@ ) +class StateFeatherstone: + def __init__(self): + self.M: wp.array | None = None + self.J: wp.array | None = None + self.P: wp.array | None = None + self.H: wp.array | None = None + self.L: wp.array | None = None + + self.joint_qdd: wp.array | None = None + self.joint_tau: wp.array | None = None + + self.joint_solve_tmp: wp.array | None = None + self.joint_S_s: wp.array | None = None + + self.body_q_com: wp.array | None = None + self.body_I_s: wp.array | None = None + self.body_v_s: wp.array | None = None + self.body_a_s: wp.array | None = None + self.body_f_s: wp.array | None = None + self.body_ft_s: wp.array | None = None + + +class ModelFeatherstone: + def __init__(self, requires_grad: bool = False): + self.requires_grad = requires_grad + self.device = wp.get_device() + + self.J_size: int = 0 + self.M_size: int = 0 + self.H_size: int = 0 + + self.tile_joint_count: int = 0 + self.tile_dof_count: int = 0 + + self.articulation_J_start: wp.array | None = None + self.articulation_M_start: wp.array | None = None + self.articulation_H_start: wp.array | None = None + + self.articulation_M_rows: wp.array | None = None + self.articulation_H_rows: wp.array | None = None + self.articulation_J_rows: wp.array | None = None + self.articulation_J_cols: wp.array | None = None + + self.articulation_dof_start: wp.array | None = None + self.articulation_coord_start: wp.array | None = None + + self.M: wp.array | None = None + self.J: wp.array | None = None + self.P: wp.array | None = None + self.H: wp.array | None = None + self.L: wp.array | None = None + + self.joint_qdd: wp.array | None = None + self.joint_tau: wp.array | None = None + + self.joint_solve_tmp: wp.array | None = None + self.joint_S_s: wp.array | None = None + + self.body_q_com: wp.array | None = None + self.body_I_s: wp.array | None = None + self.body_v_s: wp.array | None = None + self.body_a_s: wp.array | None = None + self.body_f_s: wp.array | None = None + self.body_ft_s: wp.array | None = None + + def state_custom(self, model: Model, requires_grad: bool | None = None) -> StateFeatherstone: + _s = StateFeatherstone() + if requires_grad is None: + requires_grad = self.requires_grad + + if model.joint_count: + _s.M = wp.zeros_like(self.M, requires_grad=requires_grad) + _s.J = wp.zeros_like(self.J, requires_grad=requires_grad) + _s.P = wp.zeros_like(self.P, requires_grad=requires_grad) + _s.H = wp.zeros_like(self.H, requires_grad=requires_grad) + _s.L = wp.zeros_like(self.L) + + if model.body_count: + _s.joint_qdd = wp.zeros_like(self.joint_qdd, requires_grad=requires_grad) + _s.joint_tau = wp.zeros_like(self.joint_tau, requires_grad=requires_grad) + if requires_grad: + _s.joint_solve_tmp = wp.zeros_like(self.joint_solve_tmp, requires_grad=requires_grad) + else: + _s.joint_solve_tmp = None + _s.joint_S_s = wp.zeros_like(self.joint_S_s, requires_grad=requires_grad) + + _s.body_q_com = wp.zeros_like(self.body_q_com, requires_grad=requires_grad) + _s.body_I_s = wp.zeros_like(self.body_I_s, requires_grad=requires_grad) + _s.body_v_s = wp.zeros_like(self.body_v_s, requires_grad=requires_grad) + _s.body_a_s = wp.zeros_like(self.body_a_s, requires_grad=requires_grad) + _s.body_f_s = wp.zeros_like(self.body_f_s, requires_grad=requires_grad) + _s.body_ft_s = wp.zeros_like(self.body_ft_s, requires_grad=requires_grad) + + return _s + + +class ModelBuilderFeatherstone: + def __init__(self, use_tile_gemm: bool = False): + self.use_tile_gemm = use_tile_gemm + + def _compute_articulation_indices(self, model: Model, _model: ModelFeatherstone, use_tile_gemm: bool = False): + device = model.device + with wp.ScopedDevice(device): + # calculate total size and offsets of Jacobian and mass matrices for entire system + if model.joint_count: + _model.J_size = 0 + _model.M_size = 0 + _model.H_size = 0 + + articulation_J_start = [] + articulation_M_start = [] + articulation_H_start = [] + + articulation_M_rows = [] + articulation_H_rows = [] + articulation_J_rows = [] + articulation_J_cols = [] + + articulation_dof_start = [] + articulation_coord_start = [] + + articulation_start = model.articulation_start.numpy() + joint_q_start = model.joint_q_start.numpy() + joint_qd_start = model.joint_qd_start.numpy() + + for i in range(model.articulation_count): + first_joint = articulation_start[i] + last_joint = articulation_start[i + 1] + + first_coord = joint_q_start[first_joint] + + first_dof = joint_qd_start[first_joint] + last_dof = joint_qd_start[last_joint] + + joint_count = last_joint - first_joint + dof_count = last_dof - first_dof + + articulation_J_start.append(_model.J_size) + articulation_M_start.append(_model.M_size) + articulation_H_start.append(_model.H_size) + articulation_dof_start.append(first_dof) + articulation_coord_start.append(first_coord) + + # bit of data duplication here, but will leave it as such for clarity + articulation_M_rows.append(joint_count * 6) + articulation_H_rows.append(dof_count) + articulation_J_rows.append(joint_count * 6) + articulation_J_cols.append(dof_count) + + if use_tile_gemm: + # store the joint and dof count assuming all + # articulations have the same structure + _model.tile_joint_count = int(joint_count) + _model.tile_dof_count = int(dof_count) + + _model.J_size += 6 * joint_count * dof_count + _model.M_size += 6 * joint_count * 6 * joint_count + _model.H_size += dof_count * dof_count + + # matrix offsets for batched gemm + _model.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32) + _model.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32) + _model.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32) + + _model.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32) + _model.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32) + _model.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32) + _model.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32) + + _model.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32) + _model.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32) + + def finalize_custom(self, model: Model, device: Devicelike | None = None, requires_grad: bool = False): + with wp.ScopedDevice(device): + _model = ModelFeatherstone(requires_grad=model.requires_grad) + + self._compute_articulation_indices(model, _model, use_tile_gemm=self.use_tile_gemm) + + if model.body_count: + _model.body_I_m = wp.empty((model.body_count,), dtype=wp.spatial_matrix, requires_grad=requires_grad) + wp.launch( + compute_spatial_inertia, + model.body_count, + inputs=[model.body_inertia, model.body_mass], + outputs=[_model.body_I_m], + device=model.device, + ) + _model.body_X_com = wp.empty((model.body_count,), dtype=wp.transform, requires_grad=requires_grad) + wp.launch( + compute_com_transforms, + model.body_count, + inputs=[model.body_com], + outputs=[_model.body_X_com], + device=model.device, + ) + + # allocate mass, Jacobian matrices + if model.joint_count: + # system matrices + _model.M = wp.zeros((_model.M_size,), dtype=wp.float32, requires_grad=requires_grad) + _model.J = wp.zeros((_model.J_size,), dtype=wp.float32, requires_grad=requires_grad) + _model.P = wp.empty_like(_model.J, requires_grad=requires_grad) + _model.H = wp.empty((_model.H_size,), dtype=wp.float32, requires_grad=requires_grad) + # zero since only upper triangle is set which can trigger NaN detection + _model.L = wp.zeros_like(_model.H) + + # allocate other auxiliary variables that vary with state + if model.body_count: + # joints + _model.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad) + _model.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad) + if requires_grad: + # used in the custom grad implementation of eval_dense_solve_batched + _model.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=requires_grad) + else: + _model.joint_solve_tmp = None + _model.joint_S_s = wp.empty( + (model.joint_dof_count,), dtype=wp.spatial_vector, requires_grad=requires_grad + ) + + # derived rigid body data (maximal coordinates) + _model.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad) + _model.body_I_s = wp.empty((model.body_count,), dtype=wp.spatial_matrix, requires_grad=requires_grad) + _model.body_v_s = wp.empty((model.body_count,), dtype=wp.spatial_vector, requires_grad=requires_grad) + _model.body_a_s = wp.empty((model.body_count,), dtype=wp.spatial_vector, requires_grad=requires_grad) + _model.body_f_s = wp.zeros((model.body_count,), dtype=wp.spatial_vector, requires_grad=requires_grad) + _model.body_ft_s = wp.zeros((model.body_count,), dtype=wp.spatial_vector, requires_grad=requires_grad) + + return _model + + class SolverFeatherstone(SolverBase): """A semi-implicit integrator using symplectic Euler that operates on reduced (also called generalized) coordinates to simulate articulated rigid body dynamics @@ -100,6 +332,7 @@ def __init__( friction_smoothing: float = 1.0, use_tile_gemm: bool = False, fuse_cholesky: bool = True, + enable_tri_contact: bool = True, ): """ Args: @@ -109,6 +342,7 @@ def __init__( friction_smoothing (float, optional): The delta value for the Huber norm (see :func:`warp.math.norm_huber`) used for the friction velocity normalization. Defaults to 1.0. use_tile_gemm (bool, optional): Whether to use operators from Warp's Tile API to solve for joint accelerations. Defaults to False. fuse_cholesky (bool, optional): Whether to fuse the Cholesky decomposition into the inertia matrix evaluation kernel when using the Tile API. Only used if `use_tile_gemm` is true. Defaults to True. + enable_tri_contact (bool, optional): Enable triangle contact. Defaults to True. """ super().__init__(model) @@ -117,193 +351,39 @@ def __init__( self.friction_smoothing = friction_smoothing self.use_tile_gemm = use_tile_gemm self.fuse_cholesky = fuse_cholesky + self.enable_tri_contact = enable_tri_contact self._step = 0 - self.compute_articulation_indices(model) - self.allocate_model_aux_vars(model) + # custom model attributes for Featherstone + _builder = ModelBuilderFeatherstone(use_tile_gemm=self.use_tile_gemm) + _model = _builder.finalize_custom(model, device=model.device, requires_grad=model.requires_grad) + model.featherstone = _model if self.use_tile_gemm: # create a custom kernel to evaluate the system matrix for this type if self.fuse_cholesky: self.eval_inertia_matrix_cholesky_kernel = create_inertia_matrix_cholesky_kernel( - int(self.joint_count), int(self.dof_count) + _model.tile_joint_count, _model.tile_dof_count ) else: self.eval_inertia_matrix_kernel = create_inertia_matrix_kernel( - int(self.joint_count), int(self.dof_count) + _model.tile_joint_count, _model.tile_dof_count ) # ensure matrix is reloaded since otherwise an unload can happen during graph capture # todo: should not be necessary? wp.load_module(device=wp.get_device()) - def compute_articulation_indices(self, model): - # calculate total size and offsets of Jacobian and mass matrices for entire system - if model.joint_count: - self.J_size = 0 - self.M_size = 0 - self.H_size = 0 - - articulation_J_start = [] - articulation_M_start = [] - articulation_H_start = [] - - articulation_M_rows = [] - articulation_H_rows = [] - articulation_J_rows = [] - articulation_J_cols = [] - - articulation_dof_start = [] - articulation_coord_start = [] - - articulation_start = model.articulation_start.numpy() - joint_q_start = model.joint_q_start.numpy() - joint_qd_start = model.joint_qd_start.numpy() - - for i in range(model.articulation_count): - first_joint = articulation_start[i] - last_joint = articulation_start[i + 1] - - first_coord = joint_q_start[first_joint] - - first_dof = joint_qd_start[first_joint] - last_dof = joint_qd_start[last_joint] - - joint_count = last_joint - first_joint - dof_count = last_dof - first_dof - - articulation_J_start.append(self.J_size) - articulation_M_start.append(self.M_size) - articulation_H_start.append(self.H_size) - articulation_dof_start.append(first_dof) - articulation_coord_start.append(first_coord) - - # bit of data duplication here, but will leave it as such for clarity - articulation_M_rows.append(joint_count * 6) - articulation_H_rows.append(dof_count) - articulation_J_rows.append(joint_count * 6) - articulation_J_cols.append(dof_count) - - if self.use_tile_gemm: - # store the joint and dof count assuming all - # articulations have the same structure - self.joint_count = joint_count - self.dof_count = dof_count - - self.J_size += 6 * joint_count * dof_count - self.M_size += 6 * joint_count * 6 * joint_count - self.H_size += dof_count * dof_count - - # matrix offsets for batched gemm - self.articulation_J_start = wp.array(articulation_J_start, dtype=wp.int32, device=model.device) - self.articulation_M_start = wp.array(articulation_M_start, dtype=wp.int32, device=model.device) - self.articulation_H_start = wp.array(articulation_H_start, dtype=wp.int32, device=model.device) - - self.articulation_M_rows = wp.array(articulation_M_rows, dtype=wp.int32, device=model.device) - self.articulation_H_rows = wp.array(articulation_H_rows, dtype=wp.int32, device=model.device) - self.articulation_J_rows = wp.array(articulation_J_rows, dtype=wp.int32, device=model.device) - self.articulation_J_cols = wp.array(articulation_J_cols, dtype=wp.int32, device=model.device) - - self.articulation_dof_start = wp.array(articulation_dof_start, dtype=wp.int32, device=model.device) - self.articulation_coord_start = wp.array(articulation_coord_start, dtype=wp.int32, device=model.device) - - def allocate_model_aux_vars(self, model): - # allocate mass, Jacobian matrices, and other auxiliary variables pertaining to the model - if model.joint_count: - # system matrices - self.M = wp.zeros((self.M_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad) - self.J = wp.zeros((self.J_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad) - self.P = wp.empty_like(self.J, requires_grad=model.requires_grad) - self.H = wp.empty((self.H_size,), dtype=wp.float32, device=model.device, requires_grad=model.requires_grad) - - # zero since only upper triangle is set which can trigger NaN detection - self.L = wp.zeros_like(self.H) - - if model.body_count: - self.body_I_m = wp.empty( - (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=model.requires_grad - ) - wp.launch( - compute_spatial_inertia, - model.body_count, - inputs=[model.body_inertia, model.body_mass], - outputs=[self.body_I_m], - device=model.device, - ) - self.body_X_com = wp.empty( - (model.body_count,), dtype=wp.transform, device=model.device, requires_grad=model.requires_grad - ) - wp.launch( - compute_com_transforms, - model.body_count, - inputs=[model.body_com], - outputs=[self.body_X_com], - device=model.device, - ) - - def allocate_state_aux_vars(self, model, target, requires_grad): - # allocate auxiliary variables that vary with state - if model.body_count: - # joints - target.joint_qdd = wp.zeros_like(model.joint_qd, requires_grad=requires_grad) - target.joint_tau = wp.empty_like(model.joint_qd, requires_grad=requires_grad) - if requires_grad: - # used in the custom grad implementation of eval_dense_solve_batched - target.joint_solve_tmp = wp.zeros_like(model.joint_qd, requires_grad=True) - else: - target.joint_solve_tmp = None - target.joint_S_s = wp.empty( - (model.joint_dof_count,), - dtype=wp.spatial_vector, - device=model.device, - requires_grad=requires_grad, - ) - - # derived rigid body data (maximal coordinates) - target.body_q_com = wp.empty_like(model.body_q, requires_grad=requires_grad) - target.body_I_s = wp.empty( - (model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=requires_grad - ) - target.body_v_s = wp.empty( - (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad - ) - target.body_a_s = wp.empty( - (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad - ) - target.body_f_s = wp.zeros( - (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad - ) - target.body_ft_s = wp.zeros( - (model.body_count,), dtype=wp.spatial_vector, device=model.device, requires_grad=requires_grad - ) - - target._featherstone_augmented = True - @override def step( self, state_in: State, state_out: State, - control: Control, - contacts: Contacts, + control: Control | None, + contacts: Contacts | None, dt: float, ): - requires_grad = state_in.requires_grad - - # optionally create dynamical auxiliary variables - if requires_grad: - state_aug = state_out - else: - state_aug = self - - model = self.model - - if not getattr(state_aug, "_featherstone_augmented", False): - self.allocate_state_aux_vars(model, state_aug, requires_grad) - if control is None: - control = model.control(clone_variables=False) - with wp.ScopedTimer("simulate", False): particle_f = None body_f = None @@ -314,6 +394,25 @@ def step( if state_in.body_count: body_f = state_in.body_f + model = self.model + _model = model.featherstone + + # optionally create dynamical auxiliary variables + requires_grad = state_in.requires_grad + if requires_grad: + if not hasattr(state_in, "featherstone"): + state_in.featherstone = _model.state_custom(model, requires_grad) + if not hasattr(state_out, "featherstone"): + state_out.featherstone = _model.state_custom(model, requires_grad) + _state_in = state_in.featherstone + _state_out = state_out.featherstone + else: + _state_in = _model + _state_out = _model + + if control is None: + control = model.control(clone_variables=False) + # damped springs eval_spring_forces(model, state_in, particle_f) @@ -326,16 +425,20 @@ def step( # tetrahedral FEM eval_tetrahedra_forces(model, state_in, control, particle_f) + # muscles + if False: + eval_muscle_forces(model, state_in, control, body_f) + # particle-particle interactions eval_particle_contact_forces(model, state_in, particle_f) + # triangle/triangle contacts + if self.enable_tri_contact: + eval_triangle_contact_forces(model, state_in, particle_f) + # particle shape contact eval_particle_body_contact_forces(model, state_in, contacts, particle_f, body_f, body_f_in_world_frame=True) - # muscles - if False: - eval_muscle_forces(model, state_in, control, body_f) - # ---------------------------- # articulations @@ -354,19 +457,16 @@ def step( state_in.joint_q, model.joint_X_p, model.joint_X_c, - self.body_X_com, + _model.body_X_com, model.joint_axis, model.joint_dof_dim, ], - outputs=[state_in.body_q, state_aug.body_q_com], + outputs=[state_in.body_q, _state_in.body_q_com], device=model.device, ) - # print("body_X_sc:") - # print(state_in.body_q.numpy()) - # evaluate joint inertias, motion vectors, and forces - state_aug.body_f_s.zero_() + _state_in.body_f_s.zero_() wp.launch( eval_rigid_id, @@ -380,18 +480,18 @@ def step( state_in.joint_qd, model.joint_axis, model.joint_dof_dim, - self.body_I_m, + _model.body_I_m, state_in.body_q, - state_aug.body_q_com, + _state_in.body_q_com, model.joint_X_p, model.gravity, ], outputs=[ - state_aug.joint_S_s, - state_aug.body_I_s, - state_aug.body_v_s, - state_aug.body_f_s, - state_aug.body_a_s, + _state_in.joint_S_s, + _state_in.body_I_s, + _state_in.body_v_s, + _state_in.body_f_s, + _state_in.body_a_s, ], device=model.device, ) @@ -402,7 +502,7 @@ def step( dim=contacts.rigid_contact_max, inputs=[ state_in.body_q, - state_aug.body_v_s, + _state_in.body_v_s, model.body_com, model.shape_material_ke, model.shape_material_kd, @@ -427,7 +527,8 @@ def step( if model.articulation_count: # evaluate joint torques - state_aug.body_ft_s.zero_() + _state_in.body_ft_s.zero_() + _state_in.joint_tau.zero_() wp.launch( eval_rigid_tau, dim=model.articulation_count, @@ -450,24 +551,17 @@ def step( model.joint_limit_upper, model.joint_limit_ke, model.joint_limit_kd, - state_aug.joint_S_s, - state_aug.body_f_s, + _state_in.joint_S_s, + _state_in.body_f_s, body_f, ], outputs=[ - state_aug.body_ft_s, - state_aug.joint_tau, + _state_in.body_ft_s, + _state_in.joint_tau, ], device=model.device, ) - # print("joint_tau:") - # print(state_aug.joint_tau.numpy()) - # print("body_q:") - # print(state_in.body_q.numpy()) - # print("body_qd:") - # print(state_in.body_qd.numpy()) - if self._step % self.update_mass_matrix_interval == 0: # build J wp.launch( @@ -475,12 +569,12 @@ def step( dim=model.articulation_count, inputs=[ model.articulation_start, - self.articulation_J_start, + _model.articulation_J_start, model.joint_ancestor, model.joint_qd_start, - state_aug.joint_S_s, + _state_in.joint_S_s, ], - outputs=[self.J], + outputs=[_state_out.J], device=model.device, ) @@ -490,20 +584,22 @@ def step( dim=model.articulation_count, inputs=[ model.articulation_start, - self.articulation_M_start, - state_aug.body_I_s, + _model.articulation_M_start, + _state_in.body_I_s, ], - outputs=[self.M], + outputs=[_state_out.M], device=model.device, ) if self.use_tile_gemm: # reshape arrays - M_tiled = self.M.reshape((-1, 6 * self.joint_count, 6 * self.joint_count)) - J_tiled = self.J.reshape((-1, 6 * self.joint_count, self.dof_count)) - R_tiled = model.joint_armature.reshape((-1, self.dof_count)) - H_tiled = self.H.reshape((-1, self.dof_count, self.dof_count)) - L_tiled = self.L.reshape((-1, self.dof_count, self.dof_count)) + M_tiled = _state_out.M.reshape( + (-1, 6 * _model.tile_joint_count, 6 * _model.tile_joint_count) + ) + J_tiled = _state_out.J.reshape((-1, 6 * _model.tile_joint_count, _model.tile_dof_count)) + H_tiled = _state_out.H.reshape((-1, _model.tile_dof_count, _model.tile_dof_count)) + L_tiled = _state_out.L.reshape((-1, _model.tile_dof_count, _model.tile_dof_count)) + R_tiled = model.joint_armature.reshape((-1, _model.tile_dof_count)) assert H_tiled.shape == (model.articulation_count, 18, 18) assert L_tiled.shape == (model.articulation_count, 18, 18) assert R_tiled.shape == (model.articulation_count, 18) @@ -532,12 +628,12 @@ def step( eval_dense_cholesky_batched, dim=model.articulation_count, inputs=[ - self.articulation_H_start, - self.articulation_H_rows, - self.H, + _model.articulation_H_start, + _model.articulation_H_rows, + _state_out.H, model.joint_armature, ], - outputs=[self.L], + outputs=[_state_out.L], device=model.device, ) @@ -558,19 +654,19 @@ def step( eval_dense_gemm_batched, dim=model.articulation_count, inputs=[ - self.articulation_M_rows, - self.articulation_J_cols, - self.articulation_J_rows, + _model.articulation_M_rows, + _model.articulation_J_cols, + _model.articulation_J_rows, False, False, - self.articulation_M_start, - self.articulation_J_start, + _model.articulation_M_start, + _model.articulation_J_start, # P start is the same as J start since it has the same dims as J - self.articulation_J_start, - self.M, - self.J, + _model.articulation_J_start, + _state_out.M, + _state_out.J, ], - outputs=[self.P], + outputs=[_state_out.P], device=model.device, ) @@ -579,20 +675,20 @@ def step( eval_dense_gemm_batched, dim=model.articulation_count, inputs=[ - self.articulation_J_cols, - self.articulation_J_cols, + _model.articulation_J_cols, + _model.articulation_J_cols, # P rows is the same as J rows - self.articulation_J_rows, + _model.articulation_J_rows, True, False, - self.articulation_J_start, + _model.articulation_J_start, # P start is the same as J start since it has the same dims as J - self.articulation_J_start, - self.articulation_H_start, - self.J, - self.P, + _model.articulation_J_start, + _model.articulation_H_start, + _model.J, + _model.P, ], - outputs=[self.H], + outputs=[_model.H], device=model.device, ) @@ -601,46 +697,38 @@ def step( eval_dense_cholesky_batched, dim=model.articulation_count, inputs=[ - self.articulation_H_start, - self.articulation_H_rows, - self.H, + _model.articulation_H_start, + _model.articulation_H_rows, + _state_out.H, model.joint_armature, ], - outputs=[self.L], + outputs=[_state_out.L], device=model.device, ) - - # print("joint_target:") - # print(control.joint_target.numpy()) - # print("joint_tau:") - # print(state_aug.joint_tau.numpy()) - # print("H:") - # print(self.H.numpy()) - # print("L:") - # print(self.L.numpy()) + else: + if requires_grad: + wp.copy(_state_out.H, _state_in.H) + wp.copy(_state_out.L, _state_in.L) # solve for qdd - state_aug.joint_qdd.zero_() + _state_in.joint_qdd.zero_() wp.launch( eval_dense_solve_batched, dim=model.articulation_count, inputs=[ - self.articulation_H_start, - self.articulation_H_rows, - self.articulation_dof_start, - self.H, - self.L, - state_aug.joint_tau, + _model.articulation_H_start, + _model.articulation_H_rows, + _model.articulation_dof_start, + _model.H, + _model.L, + _state_in.joint_tau, ], outputs=[ - state_aug.joint_qdd, - state_aug.joint_solve_tmp, + _state_in.joint_qdd, + _state_in.joint_solve_tmp, ], device=model.device, ) - # print("joint_qdd:") - # print(state_aug.joint_qdd.numpy()) - # print("\n\n") # ------------------------------------- # integrate bodies @@ -656,7 +744,7 @@ def step( model.joint_dof_dim, state_in.joint_q, state_in.joint_qd, - state_aug.joint_qdd, + _state_in.joint_qdd, dt, ], outputs=[state_out.joint_q, state_out.joint_qd],