• d3rlpy离线强化学习算法库安装及使用


    GitHub - takuseno/d3rlpy: An offline deep reinforcement learning library

    d3rlpy,离线强化学习算法

    我装在windows下用anaconda,按照官网教程

    conda install -c conda-forge d3rlpy

    第一次安装报错CondaSSLError: OpenSSL appears to be unavailable on this machine

    [报错解决]CondaSSLError: OpenSSL appears to be unavailable on this machine. OpenSSL is required to downl_一件迷途小书童的博客-CSDN博客

    参考这篇文章解决后正常安装没问题,值得注意的是d3rkpy安装时包含cudatoolkit11.几,我在想这个在不同电脑上可能之后会出错,不过后面运行算法时可以选择是否使用GPU

    我是打算用离线强化学习算法,安装后测试,官网上也有测试代码

    1. import d3rlpy
    2. # prepare dataset
    3. dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')
    4. # prepare algorithm
    5. cql = d3rlpy.algos.CQL(use_gpu=True)
    6. # train
    7. cql.fit(
    8. dataset,
    9. eval_episodes=dataset,
    10. n_epochs=100,
    11. scorers={
    12. 'environment': d3rlpy.metrics.evaluate_on_environment(env),
    13. 'td_error': d3rlpy.metrics.td_error_scorer,
    14. },
    15. )

    看得出来,这接口用起来非常方便啊

    因为我没装d4rl所以肯定是失败了,d4rl数据集查了下资料可能无法装在windows环境下,有点难办。可以使用下面这个在测试,用的是d3rlpy自带用于测试的数据集,也是比较常用的两个环境,具体是在d3rlpy的文档上找到的

    1. import d3rlpy
    2. # prepare dataset
    3. # dataset, env = d3rlpy.datasets.get_d4rl('CartPole-v0')
    4. dataset, env = d3rlpy.datasets.get_pendulum("random")
    5. # prepare algorithm
    6. cql = d3rlpy.algos.CQL(use_gpu=True)
    7. # train
    8. cql.fit(
    9. dataset,
    10. eval_episodes=dataset,
    11. n_epochs=100,
    12. scorers={
    13. 'environment': d3rlpy.metrics.evaluate_on_environment(env),
    14. 'td_error': d3rlpy.metrics.td_error_scorer,
    15. },
    16. )

    资料很充分,d3rlpy文档:d3rlpy.datasets.get_cartpole — d3rlpy documentation

     成功运行:

    如果失败的话可能是下载失败,

    在这找到下载网址,自己下载到本地,改成规定的名字即可,放到对d3rlpy_data文件夹里,再运行时就不需要在线下载了,比如这样

     

    之后回到d4rl,我打算把自己的数据集按照d4rl的格式来编写,但我不打算装d4rl

    可以看到在d3rlpy中读取d4rl的数据集主要是用d4rl中的get_dataset函数,于是我索性把d4rl中这个函数搬到d3rlpy中,其实就是读取h5格式的函数,也挺好移植,主要也就这一段

    1. data_dict = {}
    2. with h5py.File(h5path, 'r') as dataset_file:
    3. for k in tqdm(get_keys(dataset_file), desc="load datafile"):
    4. try: # first try loading as an array
    5. data_dict[k] = dataset_file[k][:]
    6. except ValueError as e: # try loading as a scalar
    7. data_dict[k] = dataset_file[k][()]

    注意还需要

    1. import h5py
    2. from tqdm import tqdm

    1. def get_keys(h5file):
    2. keys = []
    3. def visitor(name, item):
    4. if isinstance(item, h5py.Dataset):
    5. keys.append(name)
    6. h5file.visititems(visitor)
    7. return keys

    至于原先是个类,我感觉好像也不需要,同时还是把在线改掉,直接变成一个绝对位置(这个在d4rl中也可以找到下载的网址)

    h5path = "D:\xxx_project\pycharm\offline_RL\d3rlpy_data\hopper_random.hdf5"

    运行成功

    我考虑下一步制作自己的hdf5格式数据集,及做下自己的gym环境

    甚至不能算是入门,希望没有问题,欢迎指正

  • 相关阅读:
    CListCtrl控件为只显示一列,持滚动显示其他,不用SetScrollFlags
    C++安装qt软件教程
    Python 推导式和递归
    esp8266 Task任务创建与执行
    git远程仓库分支推送与常见问题
    Redis 集群详解及搭建过程
    开发环境搭建---Ubuntu18.04开发环境搭建
    如何在 Vue.js 中使用 Axios
    在VS Code 中调试远程服务器的PHP代码
    VR酒店专业情景教学演示
  • 原文地址:https://blog.csdn.net/Already8888/article/details/128173013