Как распечатать распределения действий в RLlib во время обучения?

Я пытаюсь выводить распределения действий в конце каждого эпизода, чтобы видеть, что делает мой агент. Для этого я добавил в пример rock_paper_scissors_multiagent.py следующий метод:

def on_episode_end(info):
    # Question's attempt: at the end of every episode, print the action
    # distribution of the "learned" policy for one fixed observation.
    # NOTE(review): np, Tuple, Discrete and get_preprocessor are not defined
    # in this snippet and must be imported at module level.
    episode = info["episode"]
    # `_policies` is a private attribute of the episode object.
    policy = episode._policies['learned']
    print(policy.model.base_model.summary())

    # NOTE(review): this is what raises the ValueError in the traceback below.
    # The "learned" policy in this example is created with observation space
    # Discrete(3), so its model expects an input of length 3, but the
    # preprocessor here is built for Tuple((Discrete(3), Discrete(3))) and
    # yields a vector of length 6 ("expected axis -1 of input shape to have
    # value 3 but received input with shape [1, 6]"). Encode a single
    # Discrete(3) observation with a preprocessor built from the policy's own
    # observation space instead (presumably `policy.observation_space` --
    # confirm for your RLlib version).
    obs_space = Tuple((Discrete(3),Discrete(3)))
    prep = get_preprocessor(obs_space)(obs_space)
    curr_state = list((0,1))
    curr_state = tuple(curr_state)
    curr_state = prep.transform(curr_state)
    logits, _ = policy.model.from_batch({"obs": np.array([curr_state])})
    dist = policy.dist_class(logits, policy.model)
    # The sampled action is discarded; its result is never used.
    dist.sample()
    # Prints only the log-probability of action 0, not the whole distribution.
    # NOTE(review): in TF graph (non-eager) mode this presumably prints a
    # symbolic tensor rather than its value -- confirm.
    print(dist.logp([0]))

Затем я добавил этот метод как обратный вызов (callback) в конфигурацию tune.run. Однако получаю ошибку, приведённую ниже. Правильный ли это способ выводить распределения политики после каждого эпизода? И если да, что я делаю не так? Исходный пример rock_paper_scissors_multiagent.py находится здесь: https://github.com/ray-project/ray/blob/master/rllib/examples/rock_paper_scissors_multiagent.py

Traceback (most recent call last):
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 515, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 351, in fetch_result
    result = ray.get(trial_future[0])
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/worker.py", line 2121, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray_worker (pid=5765, host=Charlies-MBP.fios-router.home)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 418, in train
    raise e
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 407, in train
    result = Trainable.train(self)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/trainable.py", line 176, in train
    result = self._train()
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 129, in _train
    fetches = self.optimizer.step()
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/optimizers/multi_gpu_optimizer.py", line 140, in step
    self.num_envs_per_worker, self.train_batch_size)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/optimizers/rollout.py", line 29, in collect_samples
    next_sample = ray_get_and_free(fut_sample)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/utils/memory.py", line 33, in ray_get_and_free
    result = ray.get(object_ids)
ray.exceptions.RayTaskError(ValueError): ray_worker (pid=5768, host=Charlies-MBP.fios-router.home)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 469, in sample
    batches = [self.input_reader.next()]
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 56, in next
    batches = [self.get_data()]
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 99, in get_data
    item = next(self.rollout_provider)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 319, in _env_runner
    soft_horizon, no_done_at_end)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 473, in _process_observations
    "episode": episode
  File "rock_paper_scissors_multiagent.py", line 204, in on_episode_end
    logits, _ = policy.model.from_batch({"obs": np.array([curr_state])})
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/modelv2.py", line 197, in from_batch
    return self.__call__(input_dict, states, train_batch.get("seq_lens"))
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/modelv2.py", line 154, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/tf/fcnet_v2.py", line 84, in forward
    model_out, self._value_out = self.base_model(input_dict["obs_flat"])
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 634, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 751, in call
    return self._run_internal_graph(inputs, training=training, mask=mask)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 893, in _run_internal_graph
    output_tensors = layer(computed_tensors, **kwargs)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 586, in __call__
    self.name)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/input_spec.py", line 159, in assert_input_compatibility
    ' but received input with shape ' + str(shape))
ValueError: Input 0 of layer fc_value_1 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [1, 6]  

Автор вопроса: Charlie Hou, 09.11.2019


Ответы (1)


Действия можно получить из объекта SampleBatch — `sample_obj.columns(["actions"])` — в обратном вызове `on_postprocess_traj`, определённом следующим образом:

# The callback function

def on_postprocess_traj(info):
    """Print player1's sampled actions each time a trajectory is postprocessed.

    RLlib (0.8.x dict-style callbacks) calls this with a single ``info`` dict:

        {"agent_id": ...,
         "episode": ...,
         "pre_batch": (policy, sample batch) before postprocessing,
         "post_batch": sample batch after postprocessing,
         "all_pre_batches": pre-batches of the other agent ids}

    Columns that may be present in the sample batch (which ones depends on
    the policy), see
    https://github.com/ray-project/ray/blob/ee8c9ff7320ec6a2d7d097cd5532005c6aeb216e/rllib/policy/sample_batch.py :
        t, eps_id, agent_index, obs, actions, rewards, prev_actions,
        prev_rewards, dones, infos, new_obs, action_prob, action_logp,
        vf_preds, behaviour_logits, unroll_id

    ``actions`` holds the *sampled* actions. To inspect the action
    distribution itself, print e.g. ``action_prob`` / ``action_logp`` or
    ``behaviour_logits`` instead, when the policy records them.
    """
    agent_id = info["agent_id"]
    # Only one player is printed to keep the log readable.
    if agent_id != "player1":
        return
    episode_id = info["episode"].episode_id
    # pre_batch is a (policy, sample_batch) pair; only the batch is needed.
    sample_batch = info["pre_batch"][1]
    print("agent_id = {}".format(agent_id))
    print("episode = {}".format(episode_id))
    print("actions = {}".format(sample_batch.columns(["actions"])))

Вам также нужно будет добавить функцию обратного вызова в свою конфигурацию следующим образом:

             # Excerpt from the tune.run(...) call in run_heuristic_vs_learned()
             # of the full script below: `use_lstm`, `select_policy` and
             # `on_postprocess_traj` are defined there, and the outer config
             # dict is closed after "callbacks".
             config={"env": RockPaperScissorsEnv,
                     #"eager": True,
                     "gamma": 0.9,
                     "num_workers": 1,
                     "num_envs_per_worker": 4,
                     "sample_batch_size": 10,
                     "train_batch_size": 200,
                     #"multiagent": {"policies_to_train": ["learned"],
                     "multiagent": {"policies_to_train": ["learned", "learned_2"],
                                    "policies": {"always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3), {}),
                                                 #"beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
                                                 "learned": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 "learned_2": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 },
                                    "policy_mapping_fn": select_policy,
                                   },
                      # Only on_postprocess_traj is enabled; the other hooks
                      # are listed (commented out) for reference.
                      "callbacks": {#"on_episode_start": on_episode_start, 
                                    #"on_episode_step": on_episode_step, 
                                    #"on_episode_end": on_episode_end, 
                                    #"on_sample_end": on_sample_end,
                                    "on_postprocess_traj": on_postprocess_traj,
                                    #"on_train_result": on_train_result,
                                    }

Ниже показан вывод примера rock_paper_scissors_multiagent.py (с ray[rllib]==0.8.2 в Colab); обратите внимание на распечатку идентификатора агента, идентификатора эпизода и траектории действий:

== Status ==
Memory usage on this node: 1.3/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 2/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 RUNNING)
Trial name  status  loc
PPO_RockPaperScissorsEnv_979bff44   RUNNING 


(pid=1541) 2020-04-25 12:45:10,823  INFO trainer.py:420 -- Tip: set 'eager': true or the --eager flag to enable TensorFlow eager execution
(pid=1541) 2020-04-25 12:45:10,827  INFO trainer.py:580 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
(pid=1541) /usr/local/lib/python3.6/dist-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
(pid=1541)   warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
(pid=1587) /usr/local/lib/python3.6/dist-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
(pid=1587)   warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
(pid=1541) 2020-04-25 12:45:19,048  WARNING util.py:37 -- Install gputil for GPU system monitoring.
(pid=1587) agent_id = player1
(pid=1587) episode = 975148816
(pid=1587) actions = [array([1, 1, 0, 2, 0, 0, 1, 2, 1, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 942369634
(pid=1587) actions = [array([1, 2, 1, 2, 2, 2, 1, 0, 2, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 296105405
(pid=1587) actions = [array([2, 2, 0, 2, 2, 1, 2, 1, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 475466940
(pid=1587) actions = [array([0, 2, 1, 0, 2, 0, 2, 1, 0, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 793839240
(pid=1587) actions = [array([0, 0, 1, 2, 0, 2, 1, 1, 1, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 578652318
(pid=1587) actions = [array([0, 1, 0, 0, 2, 1, 2, 2, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 112165627
(pid=1587) actions = [array([2, 1, 2, 1, 0, 0, 0, 1, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 996828544
(pid=1587) actions = [array([1, 2, 2, 2, 0, 0, 1, 2, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 94669775
(pid=1587) actions = [array([1, 0, 1, 1, 2, 0, 2, 1, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1063457620
(pid=1587) actions = [array([1, 0, 2, 1, 2, 2, 1, 2, 2, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1956229719
(pid=1587) actions = [array([0, 0, 2, 1, 2, 2, 2, 1, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 503578202
(pid=1587) actions = [array([1, 2, 0, 0, 0, 0, 1, 0, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1599756661
(pid=1587) actions = [array([0, 0, 1, 2, 0, 2, 2, 2, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1333277267
(pid=1587) actions = [array([0, 2, 1, 0, 1, 1, 2, 2, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1832916757
(pid=1587) actions = [array([1, 1, 0, 0, 2, 1, 0, 1, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 585983090
(pid=1587) actions = [array([1, 2, 1, 2, 2, 1, 0, 2, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1731969708
(pid=1587) actions = [array([2, 1, 0, 2, 2, 0, 0, 0, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 374111939
(pid=1587) actions = [array([0, 0, 0, 2, 0, 2, 2, 0, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 399432786
(pid=1587) actions = [array([0, 2, 0, 0, 0, 1, 0, 0, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 396598872
(pid=1587) actions = [array([1, 1, 0, 2, 0, 2, 0, 2, 1, 0])]
Result for PPO_RockPaperScissorsEnv_979bff44:
  custom_metrics: {}
  date: 2020-04-25_12-45-24
  done: true
  episode_len_mean: 10.0
  episode_reward_max: 0.0
  episode_reward_mean: 0.0
  episode_reward_min: 0.0
  episodes_this_iter: 20
  episodes_total: 20
  experiment_id: 87214df9c01d4efeae8edd4d656a6ca4
  experiment_tag: '0'
  hostname: 2ebf5ae102f8
  info:
    grad_time_ms: 1005.051
    learner:
      learned:
        cur_kl_coeff: 0.20000000298023224
        cur_lr: 4.999999873689376e-05
        entropy: 1.0945309400558472
        entropy_coeff: 0.0
        kl: 0.004110474139451981
        policy_loss: -0.0945899486541748
        total_loss: 2.941073417663574
        vf_explained_var: 0.00013327598571777344
        vf_loss: 3.034841299057007
      learned_2:
        cur_kl_coeff: 0.20000000298023224
        cur_lr: 4.999999873689376e-05
        entropy: 1.0941331386566162
        entropy_coeff: 0.0
        kl: 0.004472262226045132
        policy_loss: -0.0190987978130579
        total_loss: 3.0051088333129883
        vf_explained_var: 0.008207857608795166
        vf_loss: 3.023313045501709
    load_time_ms: 179.466
    num_steps_sampled: 200
    num_steps_trained: 128
    sample_time_ms: 343.341
    update_time_ms: 2861.349
  iterations_since_restore: 1
  node_ip: 172.28.0.2
  num_healthy_workers: 1
  off_policy_estimator: {}
  perf:
    cpu_util_percent: 85.65
    ram_util_percent: 16.225
  pid: 1541
  policy_reward_max:
    learned: 6.0
    learned_2: 6.0
  policy_reward_mean:
    learned: -0.15
    learned_2: 0.15
  policy_reward_min:
    learned: -6.0
    learned_2: -6.0
  sampler_perf:
    mean_env_wait_ms: 0.062040254181506584
    mean_inference_ms: 3.5300535314223347
    mean_processing_ms: 1.2217222475538068
  time_since_restore: 4.562142610549927
  time_this_iter_s: 4.562142610549927
  time_total_s: 4.562142610549927
  timestamp: 1587818724
  timesteps_since_restore: 200
  timesteps_this_iter: 200
  timesteps_total: 200
  training_iteration: 1
  trial_id: 979bff44

== Status ==
Memory usage on this node: 2.0/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 TERMINATED)
Trial name  status  loc reward  total time (s)  ts  iter
PPO_RockPaperScissorsEnv_979bff44   TERMINATED      0   4.56214 200 1


== Status ==
Memory usage on this node: 1.9/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 TERMINATED)
Trial name  status  loc reward  total time (s)  ts  iter
PPO_RockPaperScissorsEnv_979bff44   TERMINATED      0   4.56214 200 1


2020-04-25 12:45:24,345 INFO tune.py:352 -- Returning an analysis object by default. You can call `analysis.trials` to retrieve a list of trials. This message will be removed in future versions of Tune.

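Таким образом можно получить доступ не только к действиям, но и ко всей полезной информации о траектории до и после постобработки пакета. Список доступных имён столбцов (например, obs, rewards), которые тоже могут пригодиться, приведён в комментариях в функции обратного вызова. Учтите, что `actions` — это сэмплированные действия, а не само распределение. Чтобы увидеть вероятности, выводите, например, столбцы `action_prob` / `action_logp` (вероятность выбранного действия) или `behaviour_logits`, если политика их сохраняет.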

Ниже приведён полный код примера rock_paper_scissors_multiagent.py, который печатает приведённый выше вывод:

#!pip install ray[rllib]==0.8.2
# Uncomment the line above in Colab to install the RLlib version this
# example targets.

"""A simple multi-agent env with two agents playing rock paper scissors.

This demonstrates running the following policies in competition:
    (1) heuristic policy of repeating the same move
    (2) heuristic policy of beating the last opponent move
    (3) LSTM/feedforward PG policies
    (4) LSTM policy with custom entropy loss

This copy of the upstream example additionally registers an
``on_postprocess_traj`` callback that prints player1's actions for each
episode. By default (see the bottom of the file) it trains two learned
PPO policies against each other for a single short run.
"""

import argparse
import random
from gym.spaces import Discrete

from ray import tune
from ray.rllib.agents.pg.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils import try_import_tf

# TensorFlow module as resolved by RLlib; used by the custom entropy loss.
tf = try_import_tf()

# Move encoding shared by the environment and the heuristic policies.
ROCK = 0
PAPER = 1
SCISSORS = 2

# NOTE(review): nothing in this script calls parser.parse_args(), so the
# --stop option currently has no effect.
parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=400000)

class RockPaperScissorsEnv(MultiAgentEnv):
    """Two-player environment for rock paper scissors.

    The observation of each player is simply the opponent's last move
    (``Discrete(3)``). After ``reset`` both players observe ROCK (0). On each
    step both players move simultaneously: the winner gets +1, the loser -1,
    and a tie gives 0 to both. The episode ends after 10 moves.
    """

    # (player1 move, player2 move) -> (player1 reward, player2 reward).
    # Defined once on the class instead of being rebuilt on every step().
    _PAYOFFS = {
        (ROCK, ROCK): (0, 0),
        (ROCK, PAPER): (-1, 1),
        (ROCK, SCISSORS): (1, -1),
        (PAPER, ROCK): (1, -1),
        (PAPER, PAPER): (0, 0),
        (PAPER, SCISSORS): (-1, 1),
        (SCISSORS, ROCK): (-1, 1),
        (SCISSORS, PAPER): (1, -1),
        (SCISSORS, SCISSORS): (0, 0),
    }

    # Episode length in moves.
    _MAX_MOVES = 10

    def __init__(self, _):
        """Set up the spaces and agent ids; the positional argument is ignored."""
        self.action_space = Discrete(3)
        self.observation_space = Discrete(3)
        self.player1 = "player1"
        self.player2 = "player2"
        self.last_move = None
        self.num_moves = 0

    def reset(self):
        """Start a new episode and return each player's first observation."""
        self.last_move = (0, 0)
        self.num_moves = 0
        return {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }

    def step(self, action_dict):
        """Apply both players' moves for one round.

        Args:
            action_dict: maps each player id to its move (ROCK/PAPER/SCISSORS).

        Returns:
            ``(obs, rewards, dones, infos)`` dicts. ``obs`` and ``rewards`` are
            keyed by player id, ``dones`` only carries the ``"__all__"`` flag,
            and ``infos`` is empty.

        Raises:
            KeyError: if a player is missing or a move is not 0, 1 or 2.
        """
        move1 = action_dict[self.player1]
        move2 = action_dict[self.player2]
        self.last_move = (move1, move2)
        # Each player observes the other player's move.
        obs = {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }
        r1, r2 = self._PAYOFFS[move1, move2]
        rew = {
            self.player1: r1,
            self.player2: r2,
        }
        self.num_moves += 1
        done = {
            "__all__": self.num_moves >= self._MAX_MOVES,
        }
        return obs, rew, done, {}

class AlwaysSameHeuristic(Policy):
    """Heuristic policy: commit to one random move for a whole episode.

    The chosen move is carried in the recurrent state: ``compute_actions``
    returns ``state_batches`` unchanged and replays the stored move.
    """

    def get_initial_state(self):
        """Draw the move to repeat; it is the single state entry."""
        moves = [ROCK, PAPER, SCISSORS]
        return [random.choice(moves)]

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Return the stored move for every row; the state passes through."""
        stored_moves = state_batches[0]
        actions = [move for move in stored_moves]
        extra_fetches = {}
        return actions, state_batches, extra_fetches

    def learn_on_batch(self, samples):
        """No-op: this heuristic never learns."""

    def get_weights(self):
        """No trainable weights; returns None."""

    def set_weights(self, weights):
        """No trainable weights; nothing to set."""

class BeatLastHeuristic(Policy):
    """Heuristic policy: play whatever beats the opponent's previous move."""

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Map each observation row to the move that beats it.

        Each row of ``obs_batch`` is indexed as a one-hot vector over
        (ROCK, PAPER, SCISSORS) -- presumably RLlib's one-hot encoding of the
        ``Discrete(3)`` observation; TODO confirm. Entries are checked in the
        order ROCK, PAPER, SCISSORS, and a row with no entry equal to 1 yields
        None.
        """
        # (opponent's last move, move that beats it), checked in this order.
        counters = ((ROCK, PAPER), (PAPER, SCISSORS), (SCISSORS, ROCK))

        def successor(row):
            for last_move, answer in counters:
                if row[last_move] == 1:
                    return answer
            return None

        actions = list(map(successor, obs_batch))
        return actions, [], {}

    def learn_on_batch(self, samples):
        """No-op: this heuristic never learns."""

    def get_weights(self):
        """No trainable weights; returns None."""

    def set_weights(self, weights):
        """No trainable weights; nothing to set."""

def on_postprocess_traj(info):
    """Print player1's sampled actions each time a trajectory is postprocessed.

    RLlib (0.8.x dict-style callbacks) calls this with a single ``info`` dict:

        {"agent_id": ...,
         "episode": ...,
         "pre_batch": (policy, sample batch) before postprocessing,
         "post_batch": sample batch after postprocessing,
         "all_pre_batches": pre-batches of the other agent ids}

    Columns that may be present in the sample batch (which ones depends on
    the policy), see
    https://github.com/ray-project/ray/blob/ee8c9ff7320ec6a2d7d097cd5532005c6aeb216e/rllib/policy/sample_batch.py :
        t, eps_id, agent_index, obs, actions, rewards, prev_actions,
        prev_rewards, dones, infos, new_obs, action_prob, action_logp,
        vf_preds, behaviour_logits, unroll_id

    ``actions`` holds the *sampled* actions. To inspect the action
    distribution itself, print e.g. ``action_prob`` / ``action_logp`` or
    ``behaviour_logits`` instead, when the policy records them.
    """
    agent_id = info["agent_id"]
    # Only one player is printed to keep the log readable.
    if agent_id != "player1":
        return
    episode_id = info["episode"].episode_id
    # pre_batch is a (policy, sample_batch) pair; only the batch is needed.
    sample_batch = info["pre_batch"][1]
    print("agent_id = {}".format(agent_id))
    print("episode = {}".format(episode_id))
    print("actions = {}".format(sample_batch.columns(["actions"])))

def run_same_policy():
    """Train one PPO policy shared by both players (the trivial case).

    No ``multiagent`` section is configured; the upstream example used "PG"
    here instead of "PPO".
    """
    experiment_config = {"env": RockPaperScissorsEnv}
    tune.run("PPO", config=experiment_config)

def select_policy(agent_id):
    """Map an agent id to the id of the policy that controls it.

    Every id returned here must be a key of the ``policies`` dict configured
    in ``run_heuristic_vs_learned``.
    """
    if agent_id == "player1":
        return "learned"
    if agent_id == "player2":
        return "learned_2"
    # Fallback for any other agent id. The upstream example picked randomly
    # between "always_same" and "beat_last", but "beat_last" is not
    # registered in this configuration, so only "always_same" is valid.
    return "always_same"


def run_heuristic_vs_learned(use_lstm=False, trainer="PPO", stop_timesteps=3):
    """Train the learned rock-paper-scissors policies with Tune.

    In the upstream example a learned agent played against heuristic
    policies and should eventually reach a reward of ~5 with
    ``use_lstm=False`` and ~7 with ``use_lstm=True`` (an LSTM can tell the
    always_same and beat_last heuristics apart). In this configuration
    ``player1`` plays the ``learned`` policy and ``player2`` plays
    ``learned_2``, and both are trained. The ``always_same`` heuristic is
    registered, but ``select_policy`` only maps agents other than player1
    and player2 to it.

    Args:
        use_lstm: Whether the learned policies' models use an LSTM.
        trainer: Trainer name known to Tune (e.g. "PPO", "PG") or a trainer
            class.
        stop_timesteps: Stop once ``timesteps_total`` reaches this value. The
            default of 3 stops after the first training iteration, which is
            enough to see the ``on_postprocess_traj`` printouts. The upstream
            example used 400000 (the ``--stop`` default).
    """
    # To train only "learned" against a heuristic, set "policies_to_train"
    # to ["learned"] and map player2 to a heuristic in select_policy.
    tune.run(
        trainer,
        stop={"timesteps_total": stop_timesteps},
        config={
            "env": RockPaperScissorsEnv,
            "gamma": 0.9,
            "num_workers": 1,
            "num_envs_per_worker": 4,
            "sample_batch_size": 10,
            "train_batch_size": 200,
            "multiagent": {
                "policies_to_train": ["learned", "learned_2"],
                "policies": {
                    "always_same": (AlwaysSameHeuristic, Discrete(3),
                                    Discrete(3), {}),
                    "learned": (None, Discrete(3), Discrete(3),
                                {"model": {"use_lstm": use_lstm}}),
                    "learned_2": (None, Discrete(3), Discrete(3),
                                  {"model": {"use_lstm": use_lstm}}),
                },
                "policy_mapping_fn": select_policy,
            },
            # Other dict-style hooks (on_episode_start, on_episode_step,
            # on_episode_end, on_sample_end, on_train_result) can be
            # registered here as well.
            "callbacks": {
                "on_postprocess_traj": on_postprocess_traj,
            },
        },
    )

def run_with_custom_entropy_loss():
    """Example of customizing the loss function of an existing policy.

    Builds a PG policy/trainer whose loss adds an entropy bonus (weight 0.1)
    to the policy-gradient term, then trains it with an LSTM model through
    ``run_heuristic_vs_learned``. This performs about the same as the default
    loss does.
    """

    def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
        """Return ``-0.1 * mean(entropy) - mean(logp(actions) * advantages)``."""
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        # Reduce the entropy term to a scalar, like the policy-gradient term;
        # without this the returned loss is a per-row vector. Assumes
        # entropy() returns one value per batch row (as for a categorical
        # distribution) -- TODO confirm for other dist classes. If it is
        # already a scalar, reduce_mean leaves it unchanged.
        mean_entropy = tf.reduce_mean(action_dist.entropy())
        pg_objective = tf.reduce_mean(
            action_dist.logp(train_batch["actions"]) *
            train_batch["advantages"])
        return -0.1 * mean_entropy - pg_objective

    EntropyPolicy = PGTFPolicy.with_updates(
        loss_fn=entropy_policy_gradient_loss)
    EntropyLossPG = PGTrainer.with_updates(
        name="EntropyPG", get_policy_class=lambda _: EntropyPolicy)
    run_heuristic_vs_learned(use_lstm=True, trainer=EntropyLossPG)

if __name__ == "__main__":
    # Entry point. The original disabled this guard by wrapping it in a
    # string literal, so training started as soon as the module was
    # imported. Notebook kernels (e.g. Colab) also execute cells with
    # __name__ == "__main__", so this still runs there.
    # Uncomment a different experiment to try it instead:
    # run_same_policy()
    run_heuristic_vs_learned(use_lstm=False)
    # run_with_custom_entropy_loss()

Автор ответа: Huan, 25.04.2020