Skip to content

Commit

Permalink
Removed collection of arithmetic ground rules and refactored writeMap…
Browse files Browse the repository at this point in the history
… so that ground truth map, ground rule map, ground atom map, and rule map can all be arguments to the same method.
  • Loading branch information
tsalh committed Aug 27, 2020
1 parent 8e1ce89 commit 5c666b5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import org.linqs.psl.model.term.VariableTypeMap;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.VizDataCollection;

import com.healthmarketscience.sqlbuilder.BinaryCondition;
import com.healthmarketscience.sqlbuilder.CustomSql;
Expand Down Expand Up @@ -290,11 +289,6 @@ private void groundForNonSummation(Constant[] constants, Map<Variable, Integer>
groundSingleNonSummationRule(constants, variableMap, atomManager, resources);

results.addAll(resources.groundRules);
if (Options.CLI_VIZ.getBoolean()) {
for (GroundRule groundRule : resources.groundRules) {
VizDataCollection.addGroundRule(this, groundRule, variableMap, constants);
}
}
resources.groundRules.clear();
resources.accessExceptionAtoms.clear();
}
Expand All @@ -316,11 +310,6 @@ private void groundForSummation(Constant[] constants, Map<Variable, Integer> var
groundSingleSummationRule(constants, variableMap, atomManager, resources);

results.addAll(resources.groundRules);
if (Options.CLI_VIZ.getBoolean()) {
for (GroundRule groundRule : resources.groundRules) {
VizDataCollection.addGroundRule(this, groundRule, variableMap, constants);
}
}
resources.groundRules.clear();
resources.accessExceptionAtoms.clear();
}
Expand Down Expand Up @@ -350,10 +339,6 @@ private int groundAllNonSummationRule(AtomManager atomManager, GroundRuleStore g

for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) {
groundSingleNonSummationRule(results.get(groundingIndex), variableMap, atomManager, resources);
GroundRule groundRule = resources.groundRules.get(groundingIndex);
if (Options.CLI_VIZ.getBoolean()) {
VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex));
}
}

int count = resources.groundRules.size();
Expand Down Expand Up @@ -428,10 +413,6 @@ private int groundAllSummationRule(AtomManager atomManager, GroundRuleStore grou

for (int groundingIndex = 0; groundingIndex < results.size(); groundingIndex++) {
groundSingleSummationRule(results.get(groundingIndex), variableMap, atomManager, resources);
GroundRule groundRule = resources.groundRules.get(groundingIndex);
if (Options.CLI_VIZ.getBoolean()) {
VizDataCollection.addGroundRule(this, groundRule, variableMap, results.get(groundingIndex));
}
}

int count = resources.groundRules.size();
Expand Down
136 changes: 64 additions & 72 deletions psl-core/src/main/java/org/linqs/psl/util/VizDataCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,91 +49,82 @@ private static synchronized void init() {
runtime.addShutdownHook(new ShutdownHook());
}

public static void outputJSON() {
PrintStream stream = System.out;
public static void outputJSON() throws IOException {
FilterOutputStream stream = System.out;

if (outputPath != null) {
try {
stream = new PrintStream(outputPath);
if (outputPath.endsWith(".gz")) {
GZIPOutputStream gzipStream = new GZIPOutputStream(stream, true);
writeToStream(gzipStream);
gzipStream.close();
} else {
writeToStream(stream);
}
stream.close();
stream = new GZIPOutputStream(new PrintStream(outputPath));
} catch (IOException ex) {
throw new RuntimeException();
throw new RuntimeException(ex);
}
} else {
writeToStream(stream);
}
}
// Write to stream with JSON formatting.
private static void writeToStream(FilterOutputStream stream) {
// JSON format reference: https://www.json.org/json-en.html.
try {
stream.write("{".getBytes());
// Write each map as a JSON object, each JSON object is comma delimited.
writeMap(vizData.truthMap, stream, "truthMap");
stream.write(",".getBytes());
writeMap(stream, vizData.rules, "rules");
stream.write(",".getBytes());
writeMap(stream, vizData.groundRules, "groundRules");
stream.write(",".getBytes());
writeMap(stream, vizData.groundAtoms, "groundAtoms");

stream.write("}".getBytes());
} catch (IOException ex) {
throw new RuntimeException();

writeToStream(stream);

if (outputPath != null) {
stream.close();
}
}
// Write map to stream with JSON formatting.
private static void writeMap(FilterOutputStream stream, Map<String, Map<String, Object>> map, String key) {
try {
// Each key must be string formatted.
stream.write((" \"" + key + "\" :{").getBytes());

Iterator<Map.Entry<String, Map<String, Object>>> iterator = map.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, Map<String, Object>> entry = iterator.next();
JSONObject jsonObject = new JSONObject(entry.getValue());
stream.write((" \"" + entry.getKey() + "\" :" + jsonObject.toString()).getBytes());
if (iterator.hasNext()) {
stream.write(",".getBytes());
}
}
stream.write("}".getBytes());
} catch (IOException ex) {
throw new RuntimeException();
}

/**
* Write to stream with JSON formatting.
*/
private static void writeToStream(FilterOutputStream stream) throws IOException {
// JSON format reference: https://www.json.org/json-en.html.
stream.write("{ \"truthMap\" :".getBytes());

// Write each map as a JSON object, each JSON object is comma delimited.
writeMap(stream, vizData.truthMap, "truthMap");
stream.write(", \"rules\" :".getBytes());
writeMap(stream, vizData.rules, "rules");
stream.write(", \"groundRules\" :".getBytes());
writeMap(stream, vizData.groundRules, "groundRules");
stream.write(", \"groundAtoms\" :".getBytes());
writeMap(stream, vizData.groundAtoms, "groundAtoms");

stream.write('}');
}
// Write map to stream with JSON formatting.
private static void writeMap(Map<String, Float> map, FilterOutputStream stream, String key) {
try {
// Each key must be string formatted.
stream.write((" \"" + key + "\" :{").getBytes());

Iterator<Map.Entry<String, Float>> iterator = map.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, Float> entry = iterator.next();
stream.write((" \"" + entry.getKey() + "\" :" + entry.getValue()).getBytes());
if (iterator.hasNext()) {
stream.write(",".getBytes());
}

/**
* Write map to stream with JSON formatting.
*/
@SuppressWarnings("unchecked")
private static void writeMap(FilterOutputStream stream, Object map, String z) throws IOException {
stream.write('{');

Map<String, Object> stringObjMap = (Map<String, Object>) map;
Iterator<Map.Entry<String, Object>> iterator = stringObjMap.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, Object> entry = iterator.next();
stream.write((" \"" + entry.getKey() + "\" :").getBytes());

// Values of the map will either be a Float or Map
if (entry.getValue() instanceof Float) {
stream.write(entry.getValue().toString().getBytes());
} else {
// Assumption that the JSON Objects carry small amounts of data
Map<String, Object> data = (Map<String, Object>) entry.getValue();
JSONObject jsonObject = new JSONObject(data);
stream.write(jsonObject.toString().getBytes());
}

stream.write("}".getBytes());
} catch (IOException ex) {
throw new RuntimeException();
if (iterator.hasNext()) {
stream.write(',');
}
}

stream.write('}');
}

private static class ShutdownHook extends Thread {
@Override
public void run() {
outputJSON();
try {
outputJSON();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
}

Expand All @@ -154,7 +145,10 @@ public VisualizationData() {
public static void setOutputPath(String path) {
outputPath = path;
}
// Takes in a prediction truth pair and adds it to the Truth Map.

/**
* Takes in a prediction truth pair and adds it to the Truth Map.
*/
public static void addTruth(GroundAtom target, float truthVal ) {
String groundAtomID = Integer.toString(System.identityHashCode(target));
vizData.truthMap.put(groundAtomID, truthVal);
Expand All @@ -167,13 +161,11 @@ public static void dissatisfactionPerGroundRule(GroundRuleStore groundRuleStore)
if (groundRule instanceof WeightedGroundRule) {
WeightedGroundRule weightedGroundRule = (WeightedGroundRule) groundRule;
groundRuleObj.put("dissatisfaction", weightedGroundRule.getIncompatibility());
} else {
UnweightedGroundRule unweightedGroundRule = (UnweightedGroundRule) groundRule;
groundRuleObj.put("dissatisfaction", unweightedGroundRule.getInfeasibility());
}
}
}

// TODO: Collect Abstract Arithmetic Ground Rules
public static synchronized void addGroundRule(AbstractRule parentRule,
GroundRule groundRule, Map<Variable, Integer> variableMap, Constant[] constantsList) {
if (groundRule == null) {
Expand Down

0 comments on commit 5c666b5

Please sign in to comment.