Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603468404
Change-Id: Iae85e56a5afdc0e16bfb4ab57324d0c46ce7b51e
  • Loading branch information
Brax Team authored and btaba committed Feb 1, 2024
1 parent a893224 commit ec075da
Show file tree
Hide file tree
Showing 347 changed files with 892 additions and 130,412 deletions.
2 changes: 1 addition & 1 deletion brax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion brax/actuator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion brax/actuator_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
171 changes: 10 additions & 161 deletions brax/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,6 +27,7 @@
from jax.tree_util import tree_map
import mujoco
from mujoco import mjx
import numpy as np

# f: free, 1: 1-dof, 2: 2-dof, 3: 3-dof
Q_WIDTHS = {'f': 7, '1': 1, '2': 2, '3': 3}
Expand Down Expand Up @@ -355,152 +356,16 @@ class DoF(Base):
solver_params: jax.Array


@struct.dataclass
class Geometry(Base):
"""A surface or spatial volume with a shape and material properties.
Attributes:
link_idx: Link index to which this Geometry is attached
transform: transform for the geometry frame relative to the link frame, or
relative to the world frame in the case of unparented geometry
friction: resistance encountered when sliding against another geometry
elasticity: bounce/restitution encountered when hitting another geometry
solver_params: (7,) solver parameters (reference, impedance)
"""

link_idx: Optional[jax.Array]
transform: Transform
friction: jax.Array
elasticity: jax.Array
solver_params: jax.Array


@struct.dataclass
class Sphere(Geometry):
"""A sphere.
Attributes:
radius: radius of the sphere
rgba: (4,) the rgba to display in the renderer
"""

radius: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Capsule(Geometry):
"""A capsule.
Attributes:
radius: radius of the capsule end
length: distance between the two capsule end centroids
rgba: (4,) the rgba to display in the renderer
"""

radius: jax.Array
length: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Box(Geometry):
"""A box.
Attributes:
halfsize: (3,) half sizes for each box side
rgba: (4,) the rgba to display in the renderer
"""

halfsize: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Cylinder(Geometry):
"""A cylinder.
Attributes:
radius: (1,) radius of the top and bottom of the cylinder
length: (1,) length of the cylinder
rgba: (4,) the rgba to display in the renderer
"""

radius: jax.Array
length: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Plane(Geometry):
"""An infinite plane whose normal points at +z in its coordinate space.
Attributes:
rgba: (4,) the rgba to display in the renderer, currently unused
"""

rgba: Optional[jax.Array] = None


@struct.dataclass
class Mesh(Geometry):
"""A mesh loaded from an OBJ or STL file.
The mesh is expected to be in the counter-clockwise winding order.
Attributes:
vert: (num_verts, 3) spatial coordinates associated with each vertex
face: (num_faces, num_face_vertices) vertices associated with each face
rgba: (4,) the rgba to display in the renderer, currently unused
"""

vert: jax.Array
face: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Convex(Geometry):
"""A convex mesh geometry.
Attributes:
vert: (num_verts, 3) spatial coordinates associated with each vertex
face: (num_faces, num_face_vertices) vertices associated with each face
unique_edge: (num_unique, 2) vert index associated with each unique edge
rgba: (4,) the rgba to display in the renderer, currently unused
"""

vert: jax.Array
face: jax.Array
unique_edge: jax.Array
rgba: Optional[jax.Array] = None


@struct.dataclass
class Contact(Base):
class Contact(mjx.Contact, Base):
"""Contact between two geometries.
Attributes:
pos: contact position, or average of the two closest points, in world frame
normal: contact normal on the surface of geometry b
penetration: penetration distance between two geometries. positive means the
two geometries are interpenetrating, negative means they are not
friction: resistance encountered when sliding against another geometry
link_idx: Tuple of link indices participating in contact.
elasticity: bounce/restitution encountered when hitting another geometry
solver_params: (7,) collision constraint solver parameters
link_idx: Tuple of link indices participating in contact. The second part
of the tuple can be None if the second geometry is static.
"""

pos: jax.Array
normal: jax.Array
penetration: jax.Array
friction: jax.Array
# only used by `brax.physics.spring` and `brax.physics.positional`:
link_idx: jax.Array
elasticity: jax.Array
solver_params: jax.Array

link_idx: Tuple[jax.Array, Optional[jax.Array]]


@struct.dataclass
Expand Down Expand Up @@ -547,8 +412,7 @@ class State:
contact: Optional[Contact]


@struct.dataclass
class System(Base):
class System(mjx.Model):
r"""Describes a physical environment: its links, joints and geometries.
Attributes:
Expand All @@ -558,9 +422,9 @@ class System(Base):
density: (1,) density of the medium applied to all links
link: (num_link,) the links in the system
dof: (qd_size,) every degree of freedom for the system
geoms: list of batched geoms grouped by type
actuator: actuators that can be applied to links
init_q: (q_size,) initial q position for the system
elasticity: bounce/restitution encountered when hitting another geometry
vel_damping: (1,) linear vel damping applied to each body.
ang_damping: (1,) angular vel damping applied to each body.
baumgarte_erp: how aggressively interpenetrating bodies should push away\
Expand All @@ -572,9 +436,6 @@ class System(Base):
collide_scale: fraction of position based collide update to apply
enable_fluid: (1,) enables or disables fluid forces based on the
default viscosity and density parameters provided in the XML
geom_masks: 64-bit mask determines whether two geoms will be contact tested.
lower 32 bits are type, upper 32 bits are affinity. two geoms
a, b will be contact tested if a.type & b.affinity != 0
link_names: (num_link,) link names
link_types: (num_link,) string specifying the joint type of each link
valid types are:
Expand All @@ -587,6 +448,7 @@ class System(Base):
matrix_inv_iterations: maximum number of iterations of the matrix inverse
solver_iterations: maximum number of iterations of the constraint solver
solver_maxls: maximum number of line searches of the constraint solver
mj_model: mujoco.MjModel that was used to build this brax System
"""

dt: jax.Array
Expand All @@ -595,10 +457,10 @@ class System(Base):
density: Union[float, jax.Array]
link: Link
dof: DoF
geoms: List[Geometry]
actuator: Actuator
init_q: jax.Array
# only used in `brax.physics.spring` and `brax.physics.positional`:
elasticity: jax.Array
vel_damping: Union[float, jax.Array]
ang_damping: Union[float, jax.Array]
baumgarte_erp: Union[float, jax.Array]
Expand All @@ -610,15 +472,14 @@ class System(Base):
collide_scale: Union[float, jax.Array]
# non-pytree nodes
enable_fluid: bool = struct.field(pytree_node=False)
geom_masks: List[int] = struct.field(pytree_node=False)
link_names: List[str] = struct.field(pytree_node=False)
link_types: str = struct.field(pytree_node=False)
link_parents: Tuple[int, ...] = struct.field(pytree_node=False)
# only used in `brax.physics.generalized`:
matrix_inv_iterations: int = struct.field(pytree_node=False)
solver_iterations: int = struct.field(pytree_node=False)
solver_maxls: int = struct.field(pytree_node=False)
_model: mujoco.MjModel = struct.field(pytree_node=False, default=None)
mj_model: mujoco.MjModel = struct.field(pytree_node=False, default=None)

def num_links(self) -> int:
"""Returns the number of links in the system."""
Expand Down Expand Up @@ -680,18 +541,6 @@ def act_size(self) -> int:
"""Returns the act dimension for the system."""
return self.actuator.q_id.shape[0]

def set_model(self, model: mujoco.MjModel):
"""Sets the source MuJoCo model of this System."""
object.__setattr__(self, '_model', model)

def get_model(self) -> mujoco.MjModel:
"""Returns the source MuJoCo model of this System."""
return self._model

def get_mjx_model(self) -> mjx.Model:
"""Returns an MJX model of this System."""
return mjx.put_model(getattr(self, '_model'))


# below are some operation dispatch derivations

Expand Down
23 changes: 6 additions & 17 deletions brax/base_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,24 +36,13 @@ def test_write_mass_value(self):
sys_w = sys.tree_replace({'link.inertia.mass': 1.0})
self.assertEqual(sys_w.link.inertia.mass, 1.0)

def test_write_friction_value(self):
def test_write_array(self):
sys = test_utils.load_fixture('ant.xml')
sys_w = sys.tree_replace({'geoms.friction': None})
self.assertTrue(np.all([g.friction is None for g in sys_w.geoms]))
np.random.seed(0)

def test_write_friction_array(self):
sys = test_utils.load_fixture('ant.xml')
rng = jax.random.PRNGKey(0)
shape = [g.friction.shape[0] for g in sys.geoms]
expected_friction = []
for s in shape:
rng, key = jax.random.split(rng)
expected_friction.append(jax.random.uniform(key, (s,)))

sys_w = sys.tree_replace({'geoms.friction': expected_friction})
friction = np.concatenate([g.friction for g in sys_w.geoms])
expected_friction = np.concatenate(expected_friction)
np.testing.assert_array_equal(friction, expected_friction)
expected = np.random.uniform(sys.elasticity.shape)
sys_w = sys.tree_replace({'elasticity': expected})
np.testing.assert_array_equal(sys_w.elasticity, expected)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion brax/com.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion brax/com_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The Brax Authors.
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
67 changes: 67 additions & 0 deletions brax/contact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=g-multiple-import
"""Calculations for generating contacts."""

from typing import Optional
from brax import math
from brax.base import Contact
from brax.base import System
from brax.base import Transform
import jax
from jax import numpy as jp
from mujoco import mjx


def get(sys: System, x: Transform) -> Optional[Contact]:
"""Calculates contacts.
Args:
sys: system defining the kinematic tree and other properties
x: link transforms in world frame
Returns:
Contact pytree
"""
ncon = mjx.ncon(sys)
if not ncon:
return None

@jax.vmap
def local_to_global(pos1, quat1, pos2, quat2):
pos = pos1 + math.rotate(pos2, quat1)
mat = math.quat_to_3x3(math.quat_mul(quat1, quat2))
return pos, mat

x = x.concatenate(Transform.zero((1,)))
xpos = x.pos[sys.geom_bodyid - 1]
xquat = x.rot[sys.geom_bodyid - 1]
geom_xpos, geom_xmat = local_to_global(
xpos, xquat, sys.geom_pos, sys.geom_quat
)

# pytype: disable=wrong-arg-types
d = mjx.make_data(sys).replace(geom_xpos=geom_xpos, geom_xmat=geom_xmat)
d = mjx.collision(sys, d)
# pytype: enable=wrong-arg-types

c = d.contact
elasticity = (sys.elasticity[c.geom1] + sys.elasticity[c.geom2]) * 0.5

body1 = jp.array(sys.geom_bodyid)[c.geom1] - 1
body2 = jp.array(sys.geom_bodyid)[c.geom2] - 1
link_idx = (body1, body2)

return Contact(elasticity=elasticity, link_idx=link_idx, **c.__dict__)
Loading

0 comments on commit ec075da

Please sign in to comment.