Tutorials¶
CartPole¶
Download dataset¶
First of all, download the cartpole dataset as follows:
$ wget https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv?dl=1 -O cartpole.csv
Or access to https://www.dropbox.com/s/vc7fm7qdnu0kh01/cartpole.csv .
Train¶
Follow instruction from Upload Dataset to Start Training.
Deploy¶
Finally, you can download the trained policy as Export Policy Function. At this time, you have two options of the model format, TorchScript and ONNX.
TorchScript¶
You can load the policy in two lines of codes only with PyTorch.
import torch
policy = torch.jit.load('policy.pt')
It’s easy, right?
Then you can write the rest of interaction codes as usual.
import gym
env = gym.make('CartPole-v0')
observation = env.reset()
while True:
# feed observation to the policy
action = policy(torch.tensor([observation], dtype=torch.float32))
# take action to get next observation
observation, _, done, _ = env.step(action[0].numpy())
# rendering environment
env.render()
# break if the episode reaches the termination
if done:
break
ONNX¶
In this tutorial, onnxruntime is used to load the model.
import onnxruntime as ort
ort_session = ort.InferenceSession('policy.onnx')
Basically, ONNX is also easy to load.
Then you can write the rest of interaction codes like above.
import gym
env = gym.make('CartPole-v0')
observation = env.reset()
while True:
# change dtype strictly to float32 and expand its shape
observation = observation.astype('f4').reshape((1, -1))
# feed observation to the policy
action = ort_session.run(None, {'input_0': observation})[0]
# take action to get next observation
observation, _, done, _ = env.step(action[0])
# rendering environment
env.render()
# break if the episode reaches the termination
if done:
break