Skip to content

Commit

Permalink
Logging of random search results #16
Browse files Browse the repository at this point in the history
  • Loading branch information
B96 committed Jul 8, 2019
1 parent 0ccd2f1 commit e0923dc
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 67 deletions.
91 changes: 62 additions & 29 deletions AutoTuning/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,6 @@
<version>3.22.1.4</version>
</dependency>

<!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-deeplearning4j</artifactId>
<version>1.0.0-beta4</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>arbiter-ui_2.11</artifactId>
<version>1.0.0-beta4</version>
</dependency>

<!-- jmetal -->
<dependency>
<groupId>org.uma.jmetal</groupId>
<artifactId>jmetal-core</artifactId>
<version>5.7</version>
</dependency>

<dependency>
<groupId>org.moeaframework</groupId>
<artifactId>moeaframework</artifactId>
<version>2.12</version>
</dependency>

<!-- Loading data from CSV -->
<dependency>
<groupId>org.apache.commons</groupId>
Expand All @@ -107,11 +82,69 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13-beta-2</version>
<scope>test</scope>
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>4.1</version>
</dependency>

<!-- EVERYTHING FOR SMAC -->
<dependency>
<groupId>commons-collections</groupId>
<artifactId>commons-collections</artifactId>
<version>3.2.1</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.3</version>
</dependency>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math</artifactId>
<version>2.2</version>
</dependency>
<dependency>
<groupId>com.smac</groupId>
<artifactId>smac</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.smac</groupId>
<artifactId>aeatk</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.smac</groupId>
<artifactId>fastrf</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.smac</groupId>
<artifactId>jcommander</artifactId>
<version>1.0.0</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-core -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>2.9.9</version>
</dependency>
<!-- https://mvnrepository.com/artifact/com.fasterxml.jackson.core/jackson-databind -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.9.9</version>
</dependency>
<!-- https://mvnrepository.com/artifact/ch.qos.logback/logback-core -->
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>1.2.3</version>
</dependency>


</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public class PredictionModel {
* - if one rule applies predict the label according to rule
* - if two rules apply predict according to more precise rule
* - if both rules are equally precise predict randomly
* - what about instances not covered by any rules??
*/

private List<Rule> rules;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package RandomSearch;

import org.bytedeco.javacpp.opencv_core;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public final class HyperparameterSpace {
Expand All @@ -30,6 +27,7 @@ public HyperparameterSpace() {
}

public HyperparameterSpace(HyperparameterSpace copyFrom) {
this.runtime = copyFrom.getRuntime();
this.coverage = copyFrom.getCoverage();
this.performance = copyFrom.getPerformance();
this.hyperParameters = clone(copyFrom.getHyperParameters());
Expand Down
56 changes: 30 additions & 26 deletions AutoTuning/src/main/java/RandomSearch/RandomSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ public class RandomSearch {
*/
private RandomSearch(String scenario, long terminationConditionInSec, int terminationConditionNrEx, boolean startWithDefault) {
this.scenario = scenario;
this.bestHyperparameterSpace = new HyperparameterSpace();
this.currentHyperparameterSpace = new HyperparameterSpace();
if (!startWithDefault) {
this.currentHyperparameterSpace = new HyperparameterSpace();
this.currentHyperparameterSpace = randomizeHyperparameters(currentHyperparameterSpace);
}
this.terminationConditionInSec = terminationConditionInSec;
this.terminationConditionNrEx = terminationConditionNrEx;
this.bestHyperparameterSpace = new HyperparameterSpace();


}

public RandomSearch(String scenario, long terminationConditionInSec, boolean startWithDefault) {
Expand Down Expand Up @@ -63,54 +65,56 @@ public void execute(Function<TabularInstance, Integer> classificationFunction, A
long startTime = System.currentTimeMillis();
int nrExecutions = 0;

while ((System.currentTimeMillis() - startTime) < (this.terminationConditionInSec * 1000) || nrExecutions < this.terminationConditionNrEx) {
// init logging the configurations
RandomSearchLogger randomSearchLogger = new RandomSearchLogger(scenario, bestHyperparameterSpace);

while ((System.currentTimeMillis() - startTime) < (terminationConditionInSec * 1000) || nrExecutions < this.terminationConditionNrEx) {

// to calculate the runtime of each Anchors run
long runtimeStart = System.currentTimeMillis();

// set all hyperparameters

anchorBuilder
.setTau(this.currentHyperparameterSpace.getParameterByName("tau").getCurrentValue().doubleValue())
.setBeamSize(this.currentHyperparameterSpace.getParameterByName("beamsize").getCurrentValue().intValue())
.setDelta(this.currentHyperparameterSpace.getParameterByName("delta").getCurrentValue().doubleValue())
.setEpsilon(this.currentHyperparameterSpace.getParameterByName("epsilon").getCurrentValue().doubleValue())
.setTauDiscrepancy(this.currentHyperparameterSpace.getParameterByName("tauDiscrepancy").getCurrentValue().doubleValue())
.setInitSampleCount(this.currentHyperparameterSpace.getParameterByName("initSampleCount").getCurrentValue().intValue());
.setTau(currentHyperparameterSpace.getParameterByName("tau").getCurrentValue().doubleValue())
.setBeamSize(currentHyperparameterSpace.getParameterByName("beamsize").getCurrentValue().intValue())
.setDelta(currentHyperparameterSpace.getParameterByName("delta").getCurrentValue().doubleValue())
.setEpsilon(currentHyperparameterSpace.getParameterByName("epsilon").getCurrentValue().doubleValue())
.setTauDiscrepancy(currentHyperparameterSpace.getParameterByName("tauDiscrepancy").getCurrentValue().doubleValue())
.setInitSampleCount(currentHyperparameterSpace.getParameterByName("initSampleCount").getCurrentValue().intValue());


// execute Coverage Pick of Anchors and get result
final List<AnchorResult<TabularInstance>> globalExplanations = new CoveragePick<>(anchorBuilder, 10,
Executors.newCachedThreadPool(), null)
.run(anchorTabular.getTabularInstances(), 20);

// // set runtime of current Anchors run
this.currentHyperparameterSpace.setRuntime(System.currentTimeMillis() - runtimeStart);
//
// // init prediction model of created global rules
// set runtime of current Anchors run
currentHyperparameterSpace.setRuntime(System.currentTimeMillis() - runtimeStart);

// predict labels of instances based on generated global rules
PredictionModel model = new PredictionModel(globalExplanations);
//
// // predict data set labels
List<Integer> prediction = model.predict(anchorTabular.getTabularInstances());
//
// // init performance measure

// init performance measure
PerformanceMeasures performanceMeasures = new PerformanceMeasures(prediction, classificationFunction, anchorTabular.getTabularInstances());
this.currentHyperparameterSpace.setPerformance(performanceMeasures.calcMeasure(measure));
this.currentHyperparameterSpace.setCoverage(performanceMeasures.getCoverage());
//
// // check if performance of current space is the best, if yes set current space as best space
if (checkIfBetter(this.currentHyperparameterSpace.getPerformance() * this.currentHyperparameterSpace.getCoverage())) {
this.bestHyperparameterSpace = new HyperparameterSpace(this.currentHyperparameterSpace);
this.bestGlobalExplanations = globalExplanations;
currentHyperparameterSpace.setPerformance(performanceMeasures.calcMeasure(measure));
currentHyperparameterSpace.setCoverage(performanceMeasures.getCoverage());
randomSearchLogger.addValuesToLogging(currentHyperparameterSpace);

// check if performance of current space is the best, if yes set current space as best space
if (checkIfBetter(currentHyperparameterSpace.getPerformance() * currentHyperparameterSpace.getCoverage())) {
bestHyperparameterSpace = new HyperparameterSpace(currentHyperparameterSpace);
bestGlobalExplanations = globalExplanations;
}
//
// // randomize all hyperparameters
this.currentHyperparameterSpace = randomizeHyperparameters(this.currentHyperparameterSpace);
currentHyperparameterSpace = randomizeHyperparameters(currentHyperparameterSpace);

nrExecutions++;
}

// RandomSearchLogger randomSearchLogger = new RandomSearchLogger(this.scenario, this.bestHyperparameterSpace);
randomSearchLogger.endLogging();

// visualize best hyperparameters and the best global explanations
visualizeBestHyperparameterSpace();
Expand Down
55 changes: 55 additions & 0 deletions AutoTuning/src/main/java/RandomSearch/RandomSearchLogger.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package RandomSearch;

import java.io.*;
import java.util.Date;

public class RandomSearchLogger {

private final FileWriter fileWriterCSV;
private String directory = "./rs-output/";

public RandomSearchLogger(String scenario, HyperparameterSpace hyperparameterSpace) {

try {
// create directory for random search output if not yet created
new File(directory + scenario).mkdirs();
File f = new File("rundata_" + new Date().getTime());
fileWriterCSV = new FileWriter(directory + scenario + File.separator + f.getName() + ".csv");
StringBuilder header = new StringBuilder();
for (Parameter p : hyperparameterSpace.getHyperParameters()) {
header.append(p.getName() + ",");
}
fileWriterCSV.write(header.toString() + "runtime,coverage,performance" + "\n");

} catch (IOException e) {
throw new IllegalStateException("Error occurred creating file", e);
}

}

public void addValuesToLogging(HyperparameterSpace hyperparameterSpace) {

try {
StringBuilder values = new StringBuilder();
for (Parameter p : hyperparameterSpace.getHyperParameters()) {
values.append(p.getCurrentValue() + ",");
}
values.append(hyperparameterSpace.getRuntime() + ",");
values.append(hyperparameterSpace.getCoverage() + ",");
values.append(hyperparameterSpace.getPerformance());

fileWriterCSV.write(values.toString() + "\n");
} catch (IOException e) {
throw new IllegalStateException("Error occurred writing in the file", e);
}

}

public void endLogging() {
try {
fileWriterCSV.close();
} catch (IOException e) {
throw new IllegalStateException("Error occurred closing the file", e);
}
}
}
19 changes: 11 additions & 8 deletions AutoTuning/src/main/java/main.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,30 @@ public class main {

public static void main(String[] args) throws IOException {
// Load dataset and its description
final AnchorTabular anchorTabular = TitanicDataset.createTabularTrainingDefinition();
final AnchorTabular anchorTabular = HellaDataset.createTabularTrainingDefinition();
final AnchorTabular anchorTabularTest = HellaDataset.createTabularTestDefinition();

// Obtain a second suitable model (RandomForest). Train it ourselves this time.
final TabularRandomForestClassifier randomForestModel = new TabularRandomForestClassifier(100);
randomForestModel.fit(anchorTabular.getTabularInstances());

// final H2OHellaWrapper h2oModel = new H2OHellaWrapper();
final H2OHellaWrapper h2oModel = new H2OHellaWrapper();

// Print the model's test data accuracy
outputTestsetAccuracy("RandomForest", randomForestModel);
//outputTestsetAccuracy("RandomForest", randomForestModel);

// Pick instance to be explained
// Next pick specific instance (countess or patrick dooley)
final TabularInstance explainedInstance = anchorTabular.getTabularInstances()[1];
final TabularInstance explainedInstance = anchorTabular.getTabularInstances()[1704];

final AnchorConstructionBuilder<TabularInstance> anchorBuilder = anchorTabular
.createDefaultBuilder(randomForestModel, explainedInstance);
final AnchorConstructionBuilder<TabularInstance> anchorBuilder = anchorTabularTest
.createDefaultBuilder(h2oModel::predict, explainedInstance);

// anchorBuilder.setTau(0.77).setBeamSize(11).setDelta(0.11).setEpsilon(0.41).setTauDiscrepancy(0.09).setInitSampleCount(3);

// RANDOM SEARCH with time condintion
RandomSearch rs = new RandomSearch("Titanic", 1, true);
rs.execute(randomForestModel, anchorBuilder, anchorTabular, PerformanceMeasures.Measure.ACCURACY);
RandomSearch rs = new RandomSearch("Hella_ASN_90", (long)18000, true);
rs.execute(h2oModel::predict, anchorBuilder, anchorTabularTest, PerformanceMeasures.Measure.ACCURACY);

// SMAC with condition
// HyperparameterSpace hyperparameterSpace = new HyperparameterSpace();
Expand Down
19 changes: 19 additions & 0 deletions AutoTuning/src/test/java/RandomSearch/RandomSearchLoggerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package RandomSearch;

import org.junit.Test;

import static org.junit.Assert.*;


public class RandomSearchLoggerTest {

@Test
public void testFileCreation(){
//Given
HyperparameterSpace hyperparameterSpace = new HyperparameterSpace();
RandomSearchLogger logger = new RandomSearchLogger("Titanic", hyperparameterSpace);

logger.endLogging();
}

}

0 comments on commit e0923dc

Please sign in to comment.