Skip to content

Commit 0101f4a

Browse files
committed
added motor example
1 parent 01b7121 commit 0101f4a

15 files changed

+32503
-0
lines changed

examples/motor/_plotting_utils.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import os
4+
from typing import List, Tuple, Union
5+
6+
def setup_plot(nrows: int = 2, ncols: int = 2, figsize: Tuple[int, int] = (10, 10)) -> Tuple[plt.Figure, Union[plt.Axes, List[plt.Axes]]]:
7+
"""
8+
Set up a matplotlib figure with the specified number of rows and columns.
9+
10+
Args:
11+
nrows (int): Number of rows in the subplot grid.
12+
ncols (int): Number of columns in the subplot grid.
13+
figsize (Tuple[int, int]): Figure size in inches (width, height).
14+
15+
Returns:
16+
Tuple[plt.Figure, Union[plt.Axes, List[plt.Axes]]]: Figure and axes objects.
17+
"""
18+
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
19+
if nrows * ncols == 1:
20+
axes = [axes]
21+
elif nrows == 1 or ncols == 1:
22+
axes = axes.flatten()
23+
return fig, axes
24+
25+
def plot_state(ax: plt.Axes, timespan: np.ndarray, actual: np.ndarray, simulated: np.ndarray, label: str, color: str = 'blue', alpha: float = 0.5) -> None:
26+
"""
27+
Plot actual and simulated state data on the given axes.
28+
29+
Args:
30+
ax (plt.Axes): The matplotlib axes to plot on.
31+
timespan (np.ndarray): Array of time points.
32+
actual (np.ndarray): Array of actual state values.
33+
simulated (np.ndarray): Array of simulated state values.
34+
label (str): Label for the state (e.g., "Angle" or "Velocity").
35+
color (str): Color for the simulated data plot.
36+
alpha (float): Alpha value for the simulated data plot.
37+
"""
38+
ax.plot(timespan, actual, label=f"Actual {label}", color="black", linestyle="dashed", linewidth=2)
39+
ax.plot(timespan, simulated, alpha=alpha, color=color, label=f"Simulated {label}")
40+
ax.set_ylabel(f"{label} (rad{'/' if label == 'Velocity' else ''}s)")
41+
ax.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
42+
ax.legend()
43+
44+
def plot_phase_portrait(ax: plt.Axes, angle: np.ndarray, velocity: np.ndarray, simulated_angle: np.ndarray, simulated_velocity: np.ndarray, color: str = 'blue', alpha: float = 0.5) -> None:
45+
"""
46+
Plot the phase portrait of actual and simulated data.
47+
48+
Args:
49+
ax (plt.Axes): The matplotlib axes to plot on.
50+
angle (np.ndarray): Array of actual angle values.
51+
velocity (np.ndarray): Array of actual velocity values.
52+
simulated_angle (np.ndarray): Array of simulated angle values.
53+
simulated_velocity (np.ndarray): Array of simulated velocity values.
54+
color (str): Color for the simulated data plot.
55+
alpha (float): Alpha value for the simulated data plot.
56+
"""
57+
ax.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
58+
ax.plot(simulated_angle, simulated_velocity, alpha=alpha, color=color, label="Simulated")
59+
ax.set_xlabel("Angle (rad)")
60+
ax.set_ylabel("Angular Velocity (rad/s)")
61+
ax.set_title("Phase Portrait")
62+
ax.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
63+
ax.legend()
64+
65+
def plot_simulation_errors(timespan: np.ndarray, angle: np.ndarray, velocity: np.ndarray, batched_states_trajectories: np.ndarray, predicted_terminal_points: np.ndarray, interval_terminal_states: np.ndarray, HORIZON: int, save_path: str = None, show: bool = False, title: str = "Simulation Errors", iteration: int = None) -> np.ndarray:
66+
"""
67+
Plot simulation errors for the pendulum system and return the frame as an image.
68+
69+
Args:
70+
timespan (np.ndarray): Array of time points.
71+
angle (np.ndarray): Array of actual angle values.
72+
velocity (np.ndarray): Array of actual velocity values.
73+
batched_states_trajectories (np.ndarray): Array of simulated state trajectories.
74+
predicted_terminal_points (np.ndarray): Array of predicted terminal points.
75+
interval_terminal_states (np.ndarray): Array of actual terminal states at intervals.
76+
HORIZON (int): Number of time steps in each interval.
77+
save_path (str): Path to save the plot. If None, the plot is not saved.
78+
show (bool): Whether to display the plot.
79+
title (str): Title for the plot.
80+
iteration (int, optional): Current iteration number for animation frames.
81+
82+
Returns:
83+
np.ndarray: Image array representing the current frame.
84+
"""
85+
fig = plt.figure(figsize=(12, 6))
86+
gs = fig.add_gridspec(2, 2)
87+
88+
ax1 = fig.add_subplot(gs[0, 0])
89+
ax2 = fig.add_subplot(gs[1, 0])
90+
ax3 = fig.add_subplot(gs[:, 1])
91+
92+
plot_state(ax1, timespan, angle, batched_states_trajectories[:, 0], "Angle")
93+
ax1.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 0], "ob", label="Predicted")
94+
ax1.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 0], "or", label="Actual")
95+
if iteration is not None:
96+
ax1.set_title(f"{title} (Iteration {iteration})")
97+
else:
98+
ax1.set_title(title)
99+
ax1.legend(loc='upper right')
100+
101+
plot_state(ax2, timespan, velocity, batched_states_trajectories[:, 1], "Velocity")
102+
ax2.plot(timespan[HORIZON + 1 :][::HORIZON], predicted_terminal_points[:-1, 1], "ob", label="Predicted")
103+
ax2.plot(timespan[HORIZON + 1 :][::HORIZON], interval_terminal_states[:, 1], "or", label="Actual")
104+
ax2.set_xlabel("Time (s)")
105+
ax2.legend(loc='upper right')
106+
107+
plot_phase_portrait(ax3, angle, velocity, batched_states_trajectories[:, 0], batched_states_trajectories[:, 1])
108+
ax3.plot(predicted_terminal_points[:-1, 0], predicted_terminal_points[:-1, 1], "ob", label="Predicted")
109+
ax3.plot(interval_terminal_states[:, 0], interval_terminal_states[:, 1], "or", label="Actual")
110+
ax3.legend(loc='upper right')
111+
112+
plt.tight_layout()
113+
114+
if save_path:
115+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
116+
plt.savefig(save_path, dpi=300)
117+
118+
if show:
119+
plt.show()
120+
121+
# Convert plot to image array
122+
fig.canvas.draw()
123+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
124+
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
125+
126+
plt.close(fig)
127+
128+
return image
129+
130+
def create_animation_frame(timespan: np.ndarray, true_trajectory: np.ndarray, current_rollout: np.ndarray, iteration: int) -> np.ndarray:
131+
"""
132+
Create a single frame for the animation of the learning process.
133+
134+
Args:
135+
timespan (np.ndarray): Array of time points.
136+
true_trajectory (np.ndarray): Array of actual state values.
137+
current_rollout (np.ndarray): Array of current simulated state values.
138+
iteration (int): Current iteration number.
139+
140+
Returns:
141+
np.ndarray: Image array representing the current frame.
142+
"""
143+
fig = plt.figure(figsize=(12, 6)) # Reduced height from 10 to 5
144+
gs = fig.add_gridspec(2, 2)
145+
146+
ax1 = fig.add_subplot(gs[0, 0])
147+
ax2 = fig.add_subplot(gs[1, 0])
148+
ax3 = fig.add_subplot(gs[:, 1])
149+
150+
plot_state(ax1, timespan, true_trajectory[:, 0], current_rollout[:, 0], "Angle", color="red")
151+
ax1.set_title(f"Iteration {iteration}")
152+
153+
plot_state(ax2, timespan, true_trajectory[:, 1], current_rollout[:, 1], "Velocity", color="red")
154+
ax2.set_xlabel("Time (s)")
155+
156+
plot_phase_portrait(ax3, true_trajectory[:, 0], true_trajectory[:, 1], current_rollout[:, 0], current_rollout[:, 1], color="red")
157+
ax3.set_title("Phase Portrait")
158+
159+
plt.tight_layout()
160+
161+
fig.canvas.draw()
162+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
163+
image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
164+
165+
plt.close(fig)
166+
167+
return image
168+
169+
def plot_full_simulation(timespan: np.ndarray, angle: np.ndarray, velocity: np.ndarray, old_rollout: np.ndarray, new_rollout: np.ndarray, save_path: str = "plots/learning_results.png", show: bool = True) -> None:
170+
"""
171+
Plot full simulation results for the pendulum system.
172+
173+
Args:
174+
timespan (np.ndarray): Array of time points.
175+
angle (np.ndarray): Array of actual angle values.
176+
velocity (np.ndarray): Array of actual velocity values.
177+
old_rollout (np.ndarray): Array of simulated states using the old model.
178+
new_rollout (np.ndarray): Array of simulated states using the new model.
179+
save_path (str): Path to save the plot.
180+
show (bool): Whether to display the plot.
181+
"""
182+
fig = plt.figure(figsize=(12, 6)) # Reduced height from 10 to 5
183+
gs = fig.add_gridspec(2, 2)
184+
185+
ax1 = fig.add_subplot(gs[0, 0])
186+
ax2 = fig.add_subplot(gs[1, 0])
187+
ax3 = fig.add_subplot(gs[:, 1])
188+
189+
plot_state(ax1, timespan, angle, old_rollout[:, 0], "Angle", color="blue", alpha=0.3)
190+
ax1.plot(timespan, new_rollout[:, 0], color="red", label="Optimized Model")
191+
192+
plot_state(ax2, timespan, velocity, old_rollout[:, 1], "Velocity", color="blue", alpha=0.3)
193+
ax2.plot(timespan, new_rollout[:, 1], color="red", label="Optimized Model")
194+
ax2.set_xlabel("Time (s)")
195+
196+
plot_phase_portrait(ax3, angle, velocity, old_rollout[:, 0], old_rollout[:, 1], color="blue", alpha=0.3)
197+
ax3.plot(new_rollout[:, 0], new_rollout[:, 1], color="red", label="Optimized Model")
198+
199+
plt.tight_layout()
200+
if save_path:
201+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
202+
plt.savefig(save_path, dpi=300)
203+
if show:
204+
plt.show()
205+
else:
206+
plt.close(fig)

0 commit comments

Comments
 (0)