Skip to content

Commit

Permalink
env fixes and updates
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Aug 1, 2023
1 parent 01da308 commit a713f55
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
5 changes: 4 additions & 1 deletion rl_core/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def pipeline_fitting_and_evaluating(self) -> float:
y_pred = pred.predict
y_true = self.val_data.target

self.metric_value = self.metric(y_score=y_pred, y_true=y_true)
try:
self.metric_value = self.metric(y_score=y_pred, y_true=y_true)
except:
self.metric_value = -0.999

return self.metric_value

Expand Down
2 changes: 1 addition & 1 deletion rl_core/environments/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def _environment_response(self, reward: float, done: bool) -> (int, bool, dict):
'metric_value': self.metric_value,
}

return reward, done, info
return reward, done, info
8 changes: 4 additions & 4 deletions rl_core/environments/linear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
from copy import deepcopy

import numpy as np
import torch
Expand Down Expand Up @@ -33,12 +34,11 @@ def reset(self, **kwargs):
self.pipeline = PipelineBuilder()
self.time_step = 0
self.metric_value = 0

self.position = 0

self.init_state()

return self.state
return deepcopy(self.state)

def _train_step(self, action):
self.last_action = action
Expand Down Expand Up @@ -69,7 +69,7 @@ def _train_step(self, action):

reward, done, info = self._environment_response(reward, done)

return self.state, reward, done, info
return deepcopy(self.state), reward, done, info

def _inference_step(self, action):
raise NotImplementedError()
Expand All @@ -81,4 +81,4 @@ def _environment_response(self, reward: float, done: bool) -> (int, bool, dict):
'metric_value': self.metric_value,
}

return reward, done, info
return reward, done, info
64 changes: 63 additions & 1 deletion test/unit/rl_test/test_environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path

import numpy as np
import torch
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.repository.operation_types_repository import OperationTypesRepository
from fedot.core.repository.tasks import TaskTypesEnum
Expand Down Expand Up @@ -136,4 +137,65 @@ def test_ensemble_pipeline_environment_pipeline_build():
assert isinstance(info['pipeline'], Pipeline)
assert isinstance(info['time_step'], int)
assert isinstance(info['metric_value'], float)
assert info['metric_value'] > 0.5
assert info['metric_value'] > 0.5


def test_linear_pipeline_environment_pipeline_reset():
primitives = OperationTypesRepository('all').suitable_operation(
task_type=TaskTypesEnum.classification)

state_dim = 2
env = LinearPipelineGenerationEnvironment(state_dim=state_dim, primitives=primitives)

path = os.path.join(str(project_root()), 'MetaFEDOT/rl_core/data/scoring_train.csv')

datasets = {
'scoring': path,
}

dataloader = DataLoader(datasets)

train_data, val_data = dataloader.get_data()
env.load_data(train_data, val_data)

actions = [env.primitives.index(p) for p in ['rf', 'dt', 'eop']]

for action in actions:
new_state, r, done, info = env.step(action)

env.reset()

assert torch.all(env.state == torch.tensor(np.array([0, 0]), dtype=torch.float64))

for action in actions:
new_state, r, done, info = env.step(action)

assert torch.all(new_state == torch.tensor(np.array([8, 2]), dtype=torch.float64))


def test_linear_pipeline_environment_state():
primitives = OperationTypesRepository('all').suitable_operation(
task_type=TaskTypesEnum.classification)

state_dim = 4
env = LinearPipelineGenerationEnvironment(state_dim=state_dim, primitives=primitives)

path = os.path.join(str(project_root()), 'MetaFEDOT/rl_core/data/scoring_train.csv')

datasets = {
'scoring': path,
}

dataloader = DataLoader(datasets)

train_data, val_data = dataloader.get_data()
env.load_data(train_data, val_data)

actions = [env.primitives.index(p) for p in ['knn', 'dt', 'scaling', 'rf']]
states = []

for action in actions:
new_state, r, done, info = env.step(action)
states.append(new_state.numpy())

assert np.unique(states, axis=1).shape[1] == len(states)

0 comments on commit a713f55

Please sign in to comment.