Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

black formatting #39

Merged
merged 6 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions exts/stride.simulator/stride/simulator/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ class Backend:
"""

def __init__(self):
"""Initialize the Backend class
"""
"""Initialize the Backend class"""
self._vehicle = None

@property
Expand Down Expand Up @@ -70,16 +69,13 @@ def update(self, dt: float):
pass

def start(self):
"""Method that when implemented should handle the begining of the simulation of vehicle
"""
"""Method that when implemented should handle the begining of the simulation of vehicle"""
pass

def stop(self):
"""Method that when implemented should handle the stopping of the simulation of vehicle
"""
"""Method that when implemented should handle the stopping of the simulation of vehicle"""
pass

def reset(self):
"""Method that when implemented, should handle the reset of the vehicle simulation to its original state
"""
"""Method that when implemented, should handle the reset of the vehicle simulation to its original state"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def __init__(self, config=None):
if config is None:
config = {}
else:
assert isinstance(config, dict), "The config parameter must be a dictionary."
assert isinstance(
config, dict
), "The config parameter must be a dictionary."

self.vehicle_id = config.get("vehicle_id", 0)
self.update_rate: float = config.get("update_rate", 250.0) # [Hz]
Expand Down Expand Up @@ -73,7 +75,9 @@ def update_sensor(self, sensor_type: str, data):
elif sensor_type == "Lidar":
self.update_lidar_data(data)
else:
carb.log_warn(f"Sensor type {sensor_type} is not supported by the ROS2 backend.")
carb.log_warn(
f"Sensor type {sensor_type} is not supported by the ROS2 backend."
)
pass
# TODO: Add support for other sensors

Expand Down
51 changes: 37 additions & 14 deletions exts/stride.simulator/stride/simulator/backends/ros2_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import os
import carb
import numpy as np
Expand All @@ -16,15 +15,25 @@
import rclpy # pylint: disable=wrong-import-position
from std_msgs.msg import Float64 # pylint: disable=unused-import, wrong-import-position
from sensor_msgs.msg import ( # pylint: disable=unused-import, wrong-import-position
Imu, PointCloud2, PointField, MagneticField, NavSatFix, NavSatStatus
Imu,
PointCloud2,
PointField,
MagneticField,
NavSatFix,
NavSatStatus,
)
from geometry_msgs.msg import ( # pylint: disable=wrong-import-position
PoseStamped,
TwistStamped,
AccelStamped,
)
from geometry_msgs.msg import PoseStamped, TwistStamped, AccelStamped # pylint: disable=wrong-import-position


# set environment variable to use ROS2
os.environ["RMW_IMPLEMENTATION"] = "rmw_cyclonedds_cpp"
os.environ["ROS_DOMAIN_ID"] = "15"


class ROS2Backend(Backend):
"""
A class representing the ROS2 backend for the simulation.
Expand Down Expand Up @@ -62,17 +71,25 @@ def __init__(self, node_name: str):
rclpy.init()
self.node = rclpy.create_node(node_name)



# Create publishers for the state of the vehicle in ENU
self.pose_pub = self.node.create_publisher(PoseStamped, node_name + "/state/pose", 10)
self.twist_pub = self.node.create_publisher(TwistStamped, node_name + "/state/twist", 10)
self.twist_inertial_pub = self.node.create_publisher(TwistStamped, node_name + "/state/twist_inertial", 10)
self.accel_pub = self.node.create_publisher(AccelStamped, node_name + "/state/accel", 10)
self.pose_pub = self.node.create_publisher(
PoseStamped, node_name + "/state/pose", 10
)
self.twist_pub = self.node.create_publisher(
TwistStamped, node_name + "/state/twist", 10
)
self.twist_inertial_pub = self.node.create_publisher(
TwistStamped, node_name + "/state/twist_inertial", 10
)
self.accel_pub = self.node.create_publisher(
AccelStamped, node_name + "/state/accel", 10
)

# Create publishers for some sensor data
self.imu_pub = self.node.create_publisher(Imu, node_name + "/sensors/imu", 10)
self.point_cloud_pub = self.node.create_publisher(PointCloud2, node_name + "/sensors/points", 10)
self.point_cloud_pub = self.node.create_publisher(
PointCloud2, node_name + "/sensors/points", 10
)

def update(self, dt: float):
"""
Expand Down Expand Up @@ -127,7 +144,9 @@ def update_lidar_data(self, data):
msg = PointCloud2()

# Flatten LiDAR data
points_flat = np.array(data["points"]).reshape(-1, 3) # Adjust based on your data's structure
points_flat = np.array(data["points"]).reshape(
-1, 3
) # Adjust based on your data's structure

# Create a PointCloud2 message
msg = PointCloud2()
Expand All @@ -139,7 +158,7 @@ def update_lidar_data(self, data):
msg.fields = [
PointField(name="x", offset=0, datatype=PointField.FLOAT32, count=1),
PointField(name="y", offset=4, datatype=PointField.FLOAT32, count=1),
PointField(name="z", offset=8, datatype=PointField.FLOAT32, count=1)
PointField(name="z", offset=8, datatype=PointField.FLOAT32, count=1),
]
msg.is_bigendian = False
msg.point_step = 12 # Float32, x, y, z
Expand Down Expand Up @@ -170,7 +189,9 @@ def update_sensor(self, sensor_type: str, data):
elif sensor_type == "Lidar":
self.update_lidar_data(data)
else:
carb.log_warn(f"Sensor type {sensor_type} is not supported by the ROS2 backend.")
carb.log_warn(
f"Sensor type {sensor_type} is not supported by the ROS2 backend."
)
pass

def update_state(self, state):
Expand All @@ -187,7 +208,9 @@ def update_state(self, state):
accel = AccelStamped()

# Update the header
pose.header.stamp = (self.node.get_clock().now().to_msg()) # TODO: fill time when the state was measured.
pose.header.stamp = (
self.node.get_clock().now().to_msg()
) # TODO: fill time when the state was measured.
twist.header.stamp = pose.header.stamp
twist_inertial.header.stamp = pose.header.stamp
accel.header.stamp = pose.header.stamp
Expand Down
23 changes: 15 additions & 8 deletions exts/stride.simulator/stride/simulator/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def on_startup(self, ext_id):

self._window = ui.Window("Stride Simulator", width=300, height=300)

self._window.deferred_dock_in("Property", ui.DockPolicy.CURRENT_WINDOW_IS_ACTIVE)
self._window.deferred_dock_in(
"Property", ui.DockPolicy.CURRENT_WINDOW_IS_ACTIVE
)

# Start the extension backend
self._stride_sim = StrideInterface()
Expand All @@ -49,27 +51,32 @@ def on_world():

def on_environment():

self._stride_sim.load_asset(SIMULATION_ENVIRONMENTS["Default Environment"], "/World/layout")
self._stride_sim.load_asset(
SIMULATION_ENVIRONMENTS["Default Environment"], "/World/layout"
)

label.text = "Load environment"

def on_simulation():

async def respawn():

self._anymal_config = AnymalCConfig()

self._anymal = AnymalC(id=0,
init_pos=[0.0, 0.0, 0.7],
init_orientation=[0.0, 0.0, 0.0, 1.0],
config=self._anymal_config)
self._anymal = AnymalC(
id=0,
init_pos=[0.0, 0.0, 0.7],
init_orientation=[0.0, 0.0, 0.0, 1.0],
config=self._anymal_config,
)

self._current_tasks = self._stride_sim.world.get_current_tasks()
await self._stride_sim.world.reset_async()
await self._stride_sim.world.pause_async()

if len(self._current_tasks) > 0:
self._stride_sim.world.add_physics_callback("tasks_step", self._world.step_async)
self._stride_sim.world.add_physics_callback(
"tasks_step", self._world.step_async
)

asyncio.ensure_future(respawn())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def altitude(self):
return self._altitude

def initialize_world(self):
"""Method that initializes the world object
"""
"""Method that initializes the world object"""

async def _on_load_world_async():
if self._world is None:
Expand Down Expand Up @@ -243,7 +242,12 @@ def load_asset(self, usd_asset: str, stage_prefix: str):
success = prim.GetReferences().AddReference(usd_asset)

if not success:
raise Exception("The usd asset" + usd_asset + "is not load at stage path " + stage_prefix)
raise Exception(
"The usd asset"
+ usd_asset
+ "is not load at stage path "
+ stage_prefix
)

def set_viewport_camera(self, camera_position, camera_target):
"""Sets the viewport camera to given position and makes it point to another target position.
Expand All @@ -256,7 +260,9 @@ def set_viewport_camera(self, camera_position, camera_target):
# Set the camera view to a fixed value
set_camera_view(eye=camera_position, target=camera_target)

def set_world_settings(self, physics_dt=None, stage_units_in_meters=None, rendering_dt=None):
def set_world_settings(
self, physics_dt=None, stage_units_in_meters=None, rendering_dt=None
):
"""
Set the current world settings to the pre-defined settings. TODO - finish the implementation of this method.
For now these new setting will never override the default ones.
Expand Down
10 changes: 8 additions & 2 deletions exts/stride.simulator/stride/simulator/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@

# Add the Isaac Sim assets to the list
for asset, path in NVIDIA_SIMULATION_ENVIRONMENTS.items():
SIMULATION_ENVIRONMENTS[asset] = NVIDIA_ASSETS_PATH + ISAAC_SIM_ENVIRONMENTS + "/" + path
SIMULATION_ENVIRONMENTS[asset] = (
NVIDIA_ASSETS_PATH + ISAAC_SIM_ENVIRONMENTS + "/" + path
)

# Define the default settings for the simulation environment
DEFAULT_WORLD_SETTINGS = {"physics_dt": 1.0 / 200.0, "stage_units_in_meters": 1.0, "rendering_dt": 1.0 / 60.0}
DEFAULT_WORLD_SETTINGS = {
"physics_dt": 1.0 / 200.0,
"stage_units_in_meters": 1.0,
"rendering_dt": 1.0 / 60.0,
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,16 @@ async def test_hello_public_function(self):
result = stride.simulator.some_public_function(4)
self.assertEqual(result, 256)


async def test_window_button(self):

# Find a label in our window
label = ui_test.find("Stride Simulator//Frame/**/Label[*]")

# Find buttons in our window
add_button = ui_test.find("Stride Simulator//Frame/**/Button[*].text=='Add'")
reset_button = ui_test.find("Stride Simulator//Frame/**/Button[*].text=='Reset'")
reset_button = ui_test.find(
"Stride Simulator//Frame/**/Button[*].text=='Reset'"
)

# Click reset button
await reset_button.click()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from omni.isaac.core.utils.nucleus import get_assets_root_path

from stride.simulator.vehicles.controllers.controller import Controller
from stride.simulator.vehicles.controllers.networks.actuator_network import LstmSeaNetwork
from stride.simulator.vehicles.controllers.networks.actuator_network import (
LstmSeaNetwork,
)

import io
import numpy as np
Expand All @@ -25,8 +27,9 @@ def __init__(self):
assets_root_path = get_assets_root_path()

# Policy
file_content = omni.client.read_file(assets_root_path +
"/Isaac/Samples/Quadruped/Anymal_Policies/policy_1.pt")[2]
file_content = omni.client.read_file(
assets_root_path + "/Isaac/Samples/Quadruped/Anymal_Policies/policy_1.pt"
)[2]
file = io.BytesIO(memoryview(file_content).tobytes())

self._policy = torch.jit.load(file)
Expand All @@ -36,13 +39,17 @@ def __init__(self):
self.base_vel_ang_scale = 0.25
self.joint_pos_scale = 1.0
self.joint_vel_scale = 0.05
self.default_joint_pos = np.array([0.0, 0.4, -0.8, 0.0, -0.4, 0.8, -0.0, 0.4, -0.8, -0.0, -0.4, 0.8])
self.default_joint_pos = np.array(
[0.0, 0.4, -0.8, 0.0, -0.4, 0.8, -0.0, 0.4, -0.8, -0.0, -0.4, 0.8]
)
self.previous_action = np.zeros(12)
self._policy_counter = 0

# Actuator network
file_content = omni.client.read_file(assets_root_path +
"/Isaac/Samples/Quadruped/Anymal_Policies/sea_net_jit2.pt")[2]
file_content = omni.client.read_file(
assets_root_path
+ "/Isaac/Samples/Quadruped/Anymal_Policies/sea_net_jit2.pt"
)[2]
file = io.BytesIO(memoryview(file_content).tobytes())
self._actuator_network = LstmSeaNetwork()
self._actuator_network.setup(file, self.default_joint_pos)
Expand Down Expand Up @@ -90,8 +97,9 @@ def advance(self, dt, obs, command):
current_joint_vel = self.state.joint_velocities
current_joint_pos = np.array(current_joint_pos.reshape([3, 4]).T.flat)
current_joint_vel = np.array(current_joint_vel.reshape([3, 4]).T.flat)
joint_torques, _ = self._actuator_network.compute_torques(current_joint_pos, current_joint_vel,
self._action_scale * self.action)
joint_torques, _ = self._actuator_network.compute_torques(
current_joint_pos, current_joint_vel, self._action_scale * self.action
)

self._policy_counter += 1

Expand Down
Loading
Loading