MuJoCo Legged-Robot Reinforcement Learning Training 04 (Quadruped)
This post introduces the joystick.py script in mujoco_playground, which builds a joystick-command environment for the Go1 robot. Its main features are randomly generated 3D velocity commands, injected sensor noise and torso perturbations, and a set of reward functions. The post walks through the code structure: the configuration (control parameters, reward coefficients, velocity-command ranges), how the robot model is loaded, and the initialization logic. The environment supports soft joint limits and command resampling and can be used to train robust gait-control policies. The configuration section pays particular attention to the Kp and Kd values.
Note: this post is a personal study note and reflects my own understanding; if you find mistakes, discussion is welcome.
Preface
This post covers joystick.py in mujoco_playground. Its base class was introduced in the previous post:
MuJoCo Legged-Robot Reinforcement Learning Training 03 (Quadruped)
I. What is joystick.py?
A joystick environment built for the Go1 robot: it randomly generates 3D velocity commands and refreshes them periodically, adds sensor noise and torso perturbations, provides both low-dimensional and privileged observations, and defines rewards for command tracking, energy use, contact, and posture. It supports soft joint limits and command resampling, which makes it well suited to training robust gait-control policies. (summary written by AI 🐶)
II. Code walkthrough
1. Imports
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Joystick task for Go1."""
from typing import Any, Dict, Optional, Union
import jax
import jax.numpy as jp
from ml_collections import config_dict
from mujoco import mjx
from mujoco.mjx._src import math
import numpy as np
from mujoco_playground._src import mjx_env
from mujoco_playground._src.locomotion.go1 import base as go1_base
from mujoco_playground._src.locomotion.go1 import go1_constants as consts
2. Configuration
Unlike legged_gym, mujoco_playground keeps the config in the same file as the environment code (you can refactor it into Isaac Gym-style separate config files yourself if you prefer).
def default_config() -> config_dict.ConfigDict:
return config_dict.create(
ctrl_dt=0.02,
sim_dt=0.004,
episode_length=1000,
Kp=35.0,
Kd=0.5,
action_repeat=1,
action_scale=0.5,
history_len=1,
soft_joint_pos_limit_factor=0.95,
noise_config=config_dict.create(
level=1.0, # Set to 0.0 to disable noise.
scales=config_dict.create(
joint_pos=0.03,
joint_vel=1.5,
gyro=0.2,
gravity=0.05,
linvel=0.1,
),
),
Note: the Kp and Kd set here simply overwrite the kp and damping values from the XML; how this is done is implemented in base.py. kp is defined on the actuators in the XML, and damping is the damping coefficient of each joint.
action_scale is 0.5 here; if you transfer this to your own robot and see high-frequency joint chatter, try lowering it.
soft_joint_pos_limit_factor=0.95 is the soft joint-limit factor: the joint range from the XML is multiplied by 0.95, and joints that leave this shrunken range are penalized by the corresponding reward term.
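To make the note above concrete, here is a minimal sketch of what such an override can look like. It only uses standard MuJoCo model fields (actuator_gainprm, actuator_biasprm, dof_damping, jnt_range) on a tiny hypothetical model; the exact lines in base.py may differ, so treat it as an illustration rather than the real implementation.

import mujoco

# A tiny hypothetical model with one position actuator, just to show which fields
# the override touches; base.py operates on the full Go1 model instead.
xml = """
<mujoco>
  <worldbody>
    <body name="trunk" pos="0 0 0.3">
      <freejoint/>
      <geom type="box" size="0.1 0.05 0.03"/>
      <body name="thigh">
        <joint name="hip" range="-0.8 0.8" damping="1.0"/>
        <geom type="capsule" size="0.02" fromto="0 0 0 0 0 -0.2"/>
      </body>
    </body>
  </worldbody>
  <actuator>
    <position joint="hip" kp="20"/>
  </actuator>
</mujoco>
"""
mj_model = mujoco.MjModel.from_xml_string(xml)

Kp, Kd = 35.0, 0.5
mj_model.actuator_gainprm[:, 0] = Kp   # stiffness of the position actuators (overrides the XML kp)
mj_model.actuator_biasprm[:, 1] = -Kp
mj_model.dof_damping[6:] = Kd          # per-joint damping; the first 6 DoFs belong to the freejoint

lowers, uppers = mj_model.jnt_range[1:].T  # jnt_range[0] belongs to the freejoint
soft_lowers = lowers * 0.95                # soft_joint_pos_limit_factor
soft_uppers = uppers * 0.95

The config listing then continues with the reward and perturbation settings: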
reward_config=config_dict.create(
scales=config_dict.create(
# Tracking.
tracking_lin_vel=1.0,
tracking_ang_vel=0.5,
# Base reward.
lin_vel_z=-0.5,
ang_vel_xy=-0.05,
orientation=-5.0,
# Other.
dof_pos_limits=-1.0,
pose=0.5,
# Other.
termination=-1.0,
stand_still=-1.0,
# Regularization.
torques=-0.0002,
action_rate=-0.01,
energy=-0.001,
# Feet.
feet_clearance=-2.0,
feet_height=-0.2,
feet_slip=-0.1,
feet_air_time=0.1,
),
tracking_sigma=0.25,
max_foot_height=0.1,
),
pert_config=config_dict.create(
enable=False,
velocity_kick=[0.0, 3.0],
kick_durations=[0.05, 0.2],
kick_wait_times=[1.0, 3.0],
),
The scales block above lists the reward scales, i.e. the coefficients that multiply the raw value returned by each reward function; how large to make them has to be weighed against the specific robot. The pert_config block configures the random velocity kicks applied to the torso and is disabled by default (enable=False).
command_config=config_dict.create(
# Uniform distribution for command amplitude.
a=[1.5, 0.8, 1.2],
# Probability of not zeroing out new command.
b=[0.9, 0.25, 0.5],
),
impl="jax",
nconmax=4 * 8192,
njmax=40,
)
This is the joystick velocity command (command) configuration:
a=[1.5, 0.8, 1.2] means:
maximum velocity along x is 1.5 m/s
maximum velocity along y is 0.8 m/s
maximum yaw rate is 1.2 rad/s
b=[0.9, 0.25, 0.5] is, according to the comment in the code, the per-axis probability of not zeroing out a newly sampled command (see the sketch below).
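The resampling itself happens in the step logic, which is not shown in this post. The sketch below illustrates how a and b are typically combined, following the comments in the config (amplitude of a uniform distribution, and per-axis probability of not zeroing the new command); it is an illustration, not the exact library code.

import jax
import jax.numpy as jp

cmd_a = jp.array([1.5, 0.8, 1.2])   # max |vx|, |vy|, |wz|
cmd_b = jp.array([0.9, 0.25, 0.5])  # probability of keeping each axis non-zero

def sample_command(rng: jax.Array) -> jax.Array:
  rng1, rng2 = jax.random.split(rng)
  cmd = jax.random.uniform(rng1, shape=(3,), minval=-cmd_a, maxval=cmd_a)
  keep = jax.random.bernoulli(rng2, cmd_b)   # per-axis Bernoulli(b)
  return jp.where(keep, cmd, jp.zeros(3))    # zero an axis with probability 1 - b

print(sample_command(jax.random.PRNGKey(0)))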
3. Building the environment class
1. Robot model import
The following is the core of the Joystick environment; it is a continuation of base.py:
class Joystick(go1_base.Go1Env):
"""Track a joystick command."""
def __init__(
self,
task: str = "flat_terrain",
config: config_dict.ConfigDict = default_config(),
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
):
if task.startswith("rough"):
config.nconmax = 8 * 8192
config.njmax = 12 + 48
super().__init__(
xml_path=consts.task_to_xml(task).as_posix(),
config=config,
config_overrides=config_overrides,
)
self._post_init()
task: str = "flat_terrain"即我上文介绍的xml地址的传入,通过将字符串传入consts.task_to_xml(task)将字符串转换成xml地址(机器人环境地址)
2. Initialization (getting ids)
def _post_init(self) -> None:
self._init_q = jp.array(self._mj_model.keyframe("home").qpos)
self._default_pose = jp.array(self._mj_model.keyframe("home").qpos[7:])
# Note: First joint is freejoint.
self._lowers, self._uppers = self.mj_model.jnt_range[1:].T
self._soft_lowers = self._lowers * self._config.soft_joint_pos_limit_factor
self._soft_uppers = self._uppers * self._config.soft_joint_pos_limit_factor
self._torso_body_id = self._mj_model.body(consts.ROOT_BODY).id
self._torso_mass = self._mj_model.body_subtreemass[self._torso_body_id]
self._feet_site_id = np.array(
[self._mj_model.site(name).id for name in consts.FEET_SITES]
)
self._floor_geom_id = self._mj_model.geom("floor").id
self._feet_geom_id = np.array(
[self._mj_model.geom(name).id for name in consts.FEET_GEOMS]
)
foot_linvel_sensor_adr = []
for site in consts.FEET_SITES:
sensor_id = self._mj_model.sensor(f"{site}_global_linvel").id
sensor_adr = self._mj_model.sensor_adr[sensor_id]
sensor_dim = self._mj_model.sensor_dim[sensor_id]
foot_linvel_sensor_adr.append(
list(range(sensor_adr, sensor_adr + sensor_dim))
)
self._foot_linvel_sensor_adr = jp.array(foot_linvel_sensor_adr)
self._cmd_a = jp.array(self._config.command_config.a)
self._cmd_b = jp.array(self._config.command_config.b)
The code above extracts information from the XML model. In MuJoCo, model elements are addressed by id; an id can be looked up by name, and it also follows the order in which elements appear in the XML. If you are deploying this on another quadruped, you usually don't need to change much here; it is easier to rename the elements in your own XML so they match the names this code expects.
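For reference, name-based lookup in the MuJoCo Python bindings looks like this (the tiny model and the names are hypothetical, just to show the pattern the code above uses for bodies, geoms, sites, and sensors):

import mujoco

xml = """
<mujoco>
  <worldbody>
    <geom name="floor" type="plane" size="1 1 0.1"/>
    <site name="FR_foot" pos="0.2 -0.1 0"/>
    <site name="FL_foot" pos="0.2 0.1 0"/>
  </worldbody>
</mujoco>
"""
m = mujoco.MjModel.from_xml_string(xml)
print(m.geom("floor").id)                               # ids follow declaration order per element type
print([m.site(n).id for n in ("FR_foot", "FL_foot")])   # [0, 1]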
3. Environment reset
def reset(self, rng: jax.Array) -> mjx_env.State:
qpos = self._init_q
qvel = jp.zeros(self.mjx_model.nv)
The robot's initial position is randomized, which makes training more robust:
# x=+U(-0.5, 0.5), y=+U(-0.5, 0.5), yaw=U(-3.14, 3.14).
rng, key = jax.random.split(rng)
dxy = jax.random.uniform(key, (2,), minval=-0.5, maxval=0.5)
qpos = qpos.at[0:2].set(qpos[0:2] + dxy)
rng, key = jax.random.split(rng)
yaw = jax.random.uniform(key, (1,), minval=-3.14, maxval=3.14)
quat = math.axis_angle_to_quat(jp.array([0, 0, 1]), yaw)
new_quat = math.quat_mul(qpos[3:7], quat)
qpos = qpos.at[3:7].set(new_quat)
The robot's initial velocity is also randomized for the same reason:
# d(xyzrpy)=U(-0.5, 0.5)
rng, key = jax.random.split(rng)
qvel = qvel.at[0:6].set(
jax.random.uniform(key, (6,), minval=-0.5, maxval=0.5)
)
The code below packs everything into the environment state and normally does not need to be changed:
data = mjx_env.make_data(
self.mj_model,
qpos=qpos,
qvel=qvel,
ctrl=qpos[7:],
impl=self.mjx_model.impl.value,
nconmax=self._config.nconmax,
njmax=self._config.njmax,
)
data = mjx.forward(self.mjx_model, data)
rng, key1, key2, key3 = jax.random.split(rng, 4)
time_until_next_pert = jax.random.uniform(
key1,
minval=self._config.pert_config.kick_wait_times[0],
maxval=self._config.pert_config.kick_wait_times[1],
)
steps_until_next_pert = jp.round(time_until_next_pert / self.dt).astype(
jp.int32
)
pert_duration_seconds = jax.random.uniform(
key2,
minval=self._config.pert_config.kick_durations[0],
maxval=self._config.pert_config.kick_durations[1],
)
pert_duration_steps = jp.round(pert_duration_seconds / self.dt).astype(
jp.int32
)
pert_mag = jax.random.uniform(
key3,
minval=self._config.pert_config.velocity_kick[0],
maxval=self._config.pert_config.velocity_kick[1],
)
rng, key1, key2 = jax.random.split(rng, 3)
time_until_next_cmd = jax.random.exponential(key1) * 5.0
steps_until_next_cmd = jp.round(time_until_next_cmd / self.dt).astype(
jp.int32
)
cmd = jax.random.uniform(
key2, shape=(3,), minval=-self._cmd_a, maxval=self._cmd_a
)
info = {
"rng": rng,
"command": cmd,
"steps_until_next_cmd": steps_until_next_cmd,
"last_act": jp.zeros(self.mjx_model.nu),
"last_last_act": jp.zeros(self.mjx_model.nu),
"feet_air_time": jp.zeros(4),
"last_contact": jp.zeros(4, dtype=bool),
"swing_peak": jp.zeros(4),
"steps_until_next_pert": steps_until_next_pert,
"pert_duration_seconds": pert_duration_seconds,
"pert_duration": pert_duration_steps,
"steps_since_last_pert": 0,
"pert_steps": 0,
"pert_dir": jp.zeros(3),
"pert_mag": pert_mag,
}
metrics = {}
for k in self._config.reward_config.scales.keys():
metrics[f"reward/{k}"] = jp.zeros(())
metrics["swing_peak"] = jp.zeros(())
obs = self._get_obs(data, info)
reward, done = jp.zeros(2)
return mjx_env.State(data, obs, reward, done, metrics, info)
4. Observation function
There are two observation vectors here: a regular state and a privileged state.
The privileged state contains noise-free quantities plus extra quantities that help training converge.
The regular state contains the noisy observations.
This split comes from the teacher-student training paradigm; there are plenty of explanations of it online if you want the details.
Note: if you plan to deploy on real hardware, every quantity in the regular state must be obtainable on the robot. For example, if your robot has no six-axis force sensor at the feet, do not put foot forces into the regular state, otherwise that information will be missing at sim2real time; you can still put it into the privileged state to help training.
I won't go through how each observation is obtained; the official code is very clear, so just read it.
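As a framework-agnostic illustration of how the two dictionaries are consumed, here is the asymmetric actor-critic / teacher-student idea in miniature. This is not the actual brax PPO wiring; the network shapes and parameters are made up, with the dimensions taken from the comments inside _get_obs (48 for state, 123 for privileged_state).

import jax.numpy as jp

def actor(params, obs):
  # Deployable policy: only sees obs["state"], which must be measurable on hardware.
  return jp.tanh(obs["state"] @ params["w_pi"])          # 12 joint targets

def critic(params, obs):
  # Critic / teacher: may use obs["privileged_state"], which exists only in simulation.
  return (obs["privileged_state"] @ params["w_v"])[0]

obs = {"state": jp.zeros(48), "privileged_state": jp.zeros(123)}
params = {"w_pi": jp.zeros((48, 12)), "w_v": jp.zeros((123, 1))}
action, value = actor(params, obs), critic(params, obs)

Now the actual _get_obs implementation: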
def _get_obs(
self, data: mjx.Data, info: dict[str, Any]
) -> Dict[str, jax.Array]:
gyro = self.get_gyro(data)
info["rng"], noise_rng = jax.random.split(info["rng"])
noisy_gyro = (
gyro
+ (2 * jax.random.uniform(noise_rng, shape=gyro.shape) - 1)
* self._config.noise_config.level
* self._config.noise_config.scales.gyro
)
gravity = self.get_gravity(data)
info["rng"], noise_rng = jax.random.split(info["rng"])
noisy_gravity = (
gravity
+ (2 * jax.random.uniform(noise_rng, shape=gravity.shape) - 1)
* self._config.noise_config.level
* self._config.noise_config.scales.gravity
)
joint_angles = data.qpos[7:]
info["rng"], noise_rng = jax.random.split(info["rng"])
noisy_joint_angles = (
joint_angles
+ (2 * jax.random.uniform(noise_rng, shape=joint_angles.shape) - 1)
* self._config.noise_config.level
* self._config.noise_config.scales.joint_pos
)
joint_vel = data.qvel[6:]
info["rng"], noise_rng = jax.random.split(info["rng"])
noisy_joint_vel = (
joint_vel
+ (2 * jax.random.uniform(noise_rng, shape=joint_vel.shape) - 1)
* self._config.noise_config.level
* self._config.noise_config.scales.joint_vel
)
linvel = self.get_local_linvel(data)
info["rng"], noise_rng = jax.random.split(info["rng"])
noisy_linvel = (
linvel
+ (2 * jax.random.uniform(noise_rng, shape=linvel.shape) - 1)
* self._config.noise_config.level
* self._config.noise_config.scales.linvel
)
state = jp.hstack([
noisy_linvel, # 3
noisy_gyro, # 3
noisy_gravity, # 3
noisy_joint_angles - self._default_pose, # 12
noisy_joint_vel, # 12
info["last_act"], # 12
info["command"], # 3
])
accelerometer = self.get_accelerometer(data)
angvel = self.get_global_angvel(data)
feet_vel = data.sensordata[self._foot_linvel_sensor_adr].ravel()
privileged_state = jp.hstack([
state,
gyro, # 3
accelerometer, # 3
gravity, # 3
linvel, # 3
angvel, # 3
joint_angles - self._default_pose, # 12
joint_vel, # 12
data.actuator_force, # 12
info["last_contact"], # 4
feet_vel, # 4*3
info["feet_air_time"], # 4
data.xfrc_applied[self._torso_body_id, :3], # 3
info["steps_since_last_pert"] >= info["steps_until_next_pert"], # 1
])
return {
"state": state,
"privileged_state": privileged_state,
}
5. Collecting the reward terms
This is not as terse as in legged_gym, where scales and reward functions are matched directly by name; in mujoco_playground the matching is done through the dictionary keys.
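What happens afterwards (in step(), which is not shown here) is roughly the key-wise multiplication below. This is a hedged sketch with made-up raw values, assuming the summed reward is also scaled by the control timestep and clipped, as is common in mujoco_playground.

import jax.numpy as jp

# Hypothetical raw reward values, as would be returned by _get_reward().
rewards = {"tracking_lin_vel": jp.array(0.8), "torques": jp.array(120.0)}
scales  = {"tracking_lin_vel": 1.0, "torques": -0.0002}      # from reward_config.scales

scaled = {k: v * scales[k] for k, v in rewards.items()}      # matched by dict key
reward = jp.clip(sum(scaled.values()) * 0.02, 0.0, 10000.0)  # 0.02 = ctrl_dt

The _get_reward method itself just collects the raw terms: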
def _get_reward(
self,
data: mjx.Data,
action: jax.Array,
info: dict[str, Any],
metrics: dict[str, Any],
done: jax.Array,
first_contact: jax.Array,
contact: jax.Array,
) -> dict[str, jax.Array]:
del metrics # Unused.
return {
"tracking_lin_vel": self._reward_tracking_lin_vel(
info["command"], self.get_local_linvel(data)
),
"tracking_ang_vel": self._reward_tracking_ang_vel(
info["command"], self.get_gyro(data)
),
"lin_vel_z": self._cost_lin_vel_z(self.get_global_linvel(data)),
"ang_vel_xy": self._cost_ang_vel_xy(self.get_global_angvel(data)),
"orientation": self._cost_orientation(self.get_upvector(data)),
"stand_still": self._cost_stand_still(info["command"], data.qpos[7:]),
"termination": self._cost_termination(done),
"pose": self._reward_pose(data.qpos[7:]),
"torques": self._cost_torques(data.actuator_force),
"action_rate": self._cost_action_rate(
action, info["last_act"], info["last_last_act"]
),
"energy": self._cost_energy(data.qvel[6:], data.actuator_force),
"feet_slip": self._cost_feet_slip(data, contact, info),
"feet_clearance": self._cost_feet_clearance(data),
"feet_height": self._cost_feet_height(
info["swing_peak"], first_contact, info
),
"feet_air_time": self._reward_feet_air_time(
info["feet_air_time"], first_contact, info["command"]
),
"dof_pos_limits": self._cost_joint_pos_limits(data.qpos[7:]),
}
6. Reward function definitions
The reward functions below are all very standard and differ little from their legged_gym counterparts; each function's purpose is described in its comment, so I won't go through them in detail 🐶
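For example, the two tracking terms that follow compute

$$ r_{\text{track}} = \exp\!\left(-\frac{e}{\sigma}\right), \qquad e = \lVert \mathbf{v}^{\text{cmd}}_{xy} - \mathbf{v}_{xy} \rVert^{2} \ \text{(linear)} \ \text{or} \ \left(\omega^{\text{cmd}}_{z} - \omega_{z}\right)^{2} \ \text{(yaw)}, \qquad \sigma = \texttt{tracking\_sigma} = 0.25, $$

so the reward approaches 1 when the measured body-frame velocity matches the command and decays smoothly as the tracking error grows.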
def _reward_tracking_lin_vel(
self,
commands: jax.Array,
local_vel: jax.Array,
) -> jax.Array:
# Tracking of linear velocity commands (xy axes).
lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))
return jp.exp(-lin_vel_error / self._config.reward_config.tracking_sigma)
def _reward_tracking_ang_vel(
self,
commands: jax.Array,
ang_vel: jax.Array,
) -> jax.Array:
# Tracking of angular velocity commands (yaw).
ang_vel_error = jp.square(commands[2] - ang_vel[2])
return jp.exp(-ang_vel_error / self._config.reward_config.tracking_sigma)
# Base-related rewards.
def _cost_lin_vel_z(self, global_linvel) -> jax.Array:
# Penalize z axis base linear velocity.
return jp.square(global_linvel[2])
def _cost_ang_vel_xy(self, global_angvel) -> jax.Array:
# Penalize xy axes base angular velocity.
return jp.sum(jp.square(global_angvel[:2]))
def _cost_orientation(self, torso_zaxis: jax.Array) -> jax.Array:
# Penalize non flat base orientation.
return jp.sum(jp.square(torso_zaxis[:2]))
# Energy related rewards.
def _cost_torques(self, torques: jax.Array) -> jax.Array:
# Penalize torques.
return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))
def _cost_energy(
self, qvel: jax.Array, qfrc_actuator: jax.Array
) -> jax.Array:
# Penalize energy consumption.
return jp.sum(jp.abs(qvel) * jp.abs(qfrc_actuator))
def _cost_action_rate(
self, act: jax.Array, last_act: jax.Array, last_last_act: jax.Array
) -> jax.Array:
del last_last_act # Unused.
return jp.sum(jp.square(act - last_act))
# Other rewards.
def _reward_pose(self, qpos: jax.Array) -> jax.Array:
# Stay close to the default pose.
weight = jp.array([1.0, 1.0, 0.1] * 4)
return jp.exp(-jp.sum(jp.square(qpos - self._default_pose) * weight))
def _cost_stand_still(
self,
commands: jax.Array,
qpos: jax.Array,
) -> jax.Array:
cmd_norm = jp.linalg.norm(commands)
return jp.sum(jp.abs(qpos - self._default_pose)) * (cmd_norm < 0.01)
def _cost_termination(self, done: jax.Array) -> jax.Array:
# Penalize early termination.
return done
def _cost_joint_pos_limits(self, qpos: jax.Array) -> jax.Array:
# Penalize joints if they cross soft limits.
out_of_limits = -jp.clip(qpos - self._soft_lowers, None, 0.0)
out_of_limits += jp.clip(qpos - self._soft_uppers, 0.0, None)
return jp.sum(out_of_limits)
# Feet related rewards.
def _cost_feet_slip(
self, data: mjx.Data, contact: jax.Array, info: dict[str, Any]
) -> jax.Array:
cmd_norm = jp.linalg.norm(info["command"])
feet_vel = data.sensordata[self._foot_linvel_sensor_adr]
vel_xy = feet_vel[..., :2]
vel_xy_norm_sq = jp.sum(jp.square(vel_xy), axis=-1)
return jp.sum(vel_xy_norm_sq * contact) * (cmd_norm > 0.01)
def _cost_feet_clearance(self, data: mjx.Data) -> jax.Array:
feet_vel = data.sensordata[self._foot_linvel_sensor_adr]
vel_xy = feet_vel[..., :2]
vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))
foot_pos = data.site_xpos[self._feet_site_id]
foot_z = foot_pos[..., -1]
delta = jp.abs(foot_z - self._config.reward_config.max_foot_height)
return jp.sum(delta * vel_norm)
def _cost_feet_height(
self,
swing_peak: jax.Array,
first_contact: jax.Array,
info: dict[str, Any],
) -> jax.Array:
cmd_norm = jp.linalg.norm(info["command"])
error = swing_peak / self._config.reward_config.max_foot_height - 1.0
return jp.sum(jp.square(error) * first_contact) * (cmd_norm > 0.01)
def _reward_feet_air_time(
self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array
) -> jax.Array:
# Reward air time.
cmd_norm = jp.linalg.norm(commands)
rew_air_time = jp.sum((air_time - 0.1) * first_contact)
rew_air_time *= cmd_norm > 0.01 # No reward for zero commands.
return rew_air_time
References
1. https://playground.mujoco.org/