Skip to content

Commit

Permalink
fix tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
Vittorio-Caggiano committed May 13, 2024
1 parent 2e159e8 commit d6ce0e0
Showing 1 changed file with 7 additions and 138 deletions.
145 changes: 7 additions & 138 deletions docs/source/tutorials/7_Fatigue_Modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"%env MUJOCO_GL=egl\n",
"# %env MUJOCO_GL=egl\n",
"import myosuite\n",
"from myosuite.utils import gym\n",
"import skvideo.io\n",
Expand Down Expand Up @@ -37,12 +37,12 @@
"\n",
"import PIL.Image, PIL.ImageDraw, PIL.ImageFont\n",
"\n",
"def add_text_to_frame(frame, text, font=\"dejavu/DejaVuSans.ttf\", pos=(20, 20), color=(255, 0, 0), fontsize=12):\n",
"def add_text_to_frame(frame, text, pos=(20, 20), color=(255, 0, 0), fontsize=12):\n",
" if isinstance(frame, np.ndarray):\n",
" frame = PIL.Image.fromarray(frame)\n",
" \n",
" draw = PIL.ImageDraw.Draw(frame)\n",
" draw.text(pos, text, fill=color, font=PIL.ImageFont.truetype(font, fontsize))\n",
" draw.text(pos, text, fill=color)\n",
" return frame"
]
},
Expand Down Expand Up @@ -380,7 +380,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -400,7 +400,7 @@
")\n",
"\n",
"model = PPO(\"MlpPolicy\", env, verbose=0)\n",
"model.learn(total_timesteps=200000, callback=checkpoint_callback)"
"model.learn(total_timesteps=200, callback=checkpoint_callback)"
]
},
{
Expand Down Expand Up @@ -578,137 +578,6 @@
"Comparison: Policy trained without fatigue"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
"\n",
"GENERATE_VIDEO = True\n",
"GENERATE_VIDEO_EPS = 4 #number of episodes that are rendered BOTH at the beginning (i.e., without fatigue) and at the end (i.e., with fatigue)\n",
"\n",
"STORE_DATA = True #store collected data from evaluation run in .npy file\n",
"n_eps = 250\n",
"\n",
"###################################\n",
"\n",
"env = gym.make(env_name)\n",
"\n",
"policy = \"../../../myosuite/agents/baslines_NPG/myoElbowPose1D6MRandom-v0/2022-02-26_21-16-27/33_env=myoElbowPose1D6MRandom-v0,seed=1/iterations/best_policy.pickle\"\n",
"pi = pickle.load(open(policy, 'rb'))\n",
"\n",
"env.unwrapped.set_fatigue_reset_random(False)\n",
"env.reset(fatigue_reset=True) #ensure that fatigue is reset before the simulation starts\n",
"\n",
"env.unwrapped.sim.model.cam_poscom0[0]= np.array([-1.3955, -0.3287, 0.6579])\n",
"\n",
"data_store = []\n",
"if GENERATE_VIDEO:\n",
" frames = []\n",
"\n",
"env.unwrapped.target_jnt_value = env.unwrapped.target_jnt_range[:, 1]\n",
"env.unwrapped.target_type = 'fixed'\n",
"env.unwrapped.update_target(restore_sim=True)\n",
"\n",
"start_time = time.time()\n",
"for ep in range(n_eps):\n",
" print(\"Ep {} of {}\".format(ep, n_eps))\n",
"\n",
" for _cstep in range(env.spec.max_episode_steps):\n",
" if GENERATE_VIDEO and (ep in range(GENERATE_VIDEO_EPS) or ep in range(n_eps-GENERATE_VIDEO_EPS, n_eps)):\n",
" frame = env.unwrapped.sim.renderer.render_offscreen(width=400, height=400, camera_id=0)\n",
" \n",
" # Add text overlay\n",
" _current_time = (ep*env.spec.max_episode_steps + _cstep)*env.unwrapped.dt\n",
" frame = np.array(add_text_to_frame(frame,\n",
" f\"t={str(int(_current_time//60)).zfill(2)}:{str(int(_current_time%60)).zfill(2)}min\",\n",
" pos=(285, 3), color=(0, 0, 0), fontsize=18))\n",
" \n",
" frames.append(frame)\n",
" o = env.unwrapped.get_obs()\n",
" a = pi.get_action(o)[0]\n",
" next_o, r, done, _, ifo = env.step(a) # take an action based on the current observation\n",
"\n",
" data_store.append({\"action\":a.copy(), \n",
" \"jpos\":env.unwrapped.sim.data.qpos.copy(), \n",
" \"mlen\":env.unwrapped.sim.data.actuator_length.copy(), \n",
" \"act\":env.unwrapped.sim.data.act.copy(),\n",
" \"reward\":r,\n",
" \"solved\":env.unwrapped.rwd_dict['solved'].item(),\n",
" \"pose_err\":env.unwrapped.get_obs_dict(env.unwrapped.sim)[\"pose_err\"],\n",
" \"MA\":env.unwrapped.muscle_fatigue.MA.copy(),\n",
" \"MR\":env.unwrapped.muscle_fatigue.MR.copy(),\n",
" \"MF\":env.unwrapped.muscle_fatigue.MF.copy(),\n",
" \"ctrl\":env.unwrapped.last_ctrl.copy()})\n",
"env.close()\n",
"\n",
"## OPTIONALLY: Stored simulated data\n",
"if STORE_DATA:\n",
" os.makedirs(f\"{env_name}/logs\", exist_ok=True)\n",
" np.save(f\"{env_name}/logs/fatitest_trained_wo_fatigue.npy\", data_store)\n",
"\n",
"## OPTIONALLY: Render video\n",
"if GENERATE_VIDEO:\n",
" os.makedirs(f'{env_name}/videos', exist_ok=True)\n",
" # make a local copy\n",
" skvideo.io.vwrite(f'{env_name}/videos/fatitest_trained_wo_fatigue.mp4', np.asarray(frames),inputdict={'-r': str(int(1/env.unwrapped.dt))},outputdict={\"-pix_fmt\": \"yuv420p\"})\n",
"\n",
"end_time = time.time()\n",
"print(f\"DURATION: {end_time - start_time:.2f}s\")\n",
"\n",
"if GENERATE_VIDEO:\n",
" display(show_video(f'{env_name}/videos/fatitest_trained_wo_fatigue.mp4'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
"\n",
"####################\n",
"\n",
"env_test = gym.make(env_name, normalize_act=False)\n",
"muscle_names = [env_test.unwrapped.sim.model.id2name(i, \"actuator\") for i in range(env_test.unwrapped.sim.model.nu) if env_test.unwrapped.sim.model.actuator_dyntype[i] == mujoco.mjtDyn.mjDYN_MUSCLE]\n",
"_env_dt = env_test.unwrapped.dt #0.02\n",
"\n",
"data_store = np.load(f\"{env_name}/logs/fatitest_trained_wo_fatigue.npy\", allow_pickle=True)\n",
"\n",
"plt.figure()\n",
"for _muscleid in range(len(data_store[0]['MF'])):\n",
" plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MF'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
"plt.legend()\n",
"plt.title('Fatigued Motor Units')\n",
"\n",
"plt.figure()\n",
"for _muscleid in range(len(data_store[0]['MR'])):\n",
" plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MR'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
"plt.legend()\n",
"plt.title('Resting Motor Units')\n",
"\n",
"plt.figure()\n",
"for _muscleid in range(len(data_store[0]['MA'])):\n",
" plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MA'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
"plt.legend()\n",
"plt.title('Active Motor Units')\n",
"\n",
"plt.figure()\n",
"plt.plot(_env_dt*np.arange(len(data_store)), np.array([np.linalg.norm(d['pose_err']) for d in data_store])), plt.title('Pose Error')\n",
"\n",
"plt.figure()\n",
"plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['reward'] for d in data_store])), plt.title(f\"Reward (Total: {np.array([d['reward'] for d in data_store]).sum():.2f})\")\n",
"\n",
"if \"solved\" in data_store[0]:\n",
" plt.figure()\n",
" plt.scatter(_env_dt*np.arange(len(data_store))[np.array([d['solved'] for d in data_store])], np.array([d['solved'] for d in data_store])[np.array([d['solved'] for d in data_store])]), plt.title(f\"Success\")\n",
"\n",
"print(f\"Muscle Fatigue Equilibrium: {data_store[-1]['MF']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -871,7 +740,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -885,7 +754,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.19"
}
},
"nbformat": 4,
Expand Down

0 comments on commit d6ce0e0

Please sign in to comment.