Skip to content

Commit

Permalink
approx: caseo: grammar: program: add #inference directive for aseo
Browse files Browse the repository at this point in the history
  • Loading branch information
RenatoGeh committed May 29, 2024
1 parent 0f34fe5 commit 0016235
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 13 deletions.
23 changes: 19 additions & 4 deletions pasp/approx.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include "caseo.h"
#include "cground.h"

static bool _aseo_maxent(program_t *P, size_t n_samples, size_t scale, double **R) {
static bool _aseo_maxent(program_t *P, size_t n_samples, size_t scale, double **R, bool quiet,
bool status) {
bool ok = false;
double *a, *b, *r = a = b = NULL;

Expand All @@ -26,7 +27,7 @@ static bool _aseo_maxent(program_t *P, size_t n_samples, size_t scale, double **
/* Run first ASEO separately to get the number of models. */
models_t M = {0};
if (!aseo_reuse(P, n_samples, MAXENT_SEMANTICS, NULL, (int) scale, 0, W, U, &M,
approx_rec_query_maxent)) goto cleanup;
approx_rec_query_maxent, false)) goto cleanup;
size_t n_M = M.n;

/* The resulting (flattened) array has dimension n*k x 1, where n is the number of queries and k
Expand All @@ -44,11 +45,18 @@ static bool _aseo_maxent(program_t *P, size_t n_samples, size_t scale, double **
approx_query_maxent_ab(P, &M, a, b);
approx_query_maxent_r(P, &M, r, a, b);
models_free_contents(&M);
if (!quiet) {
for (size_t i = 0; i < P->Q_n; ++i) {
print_query(P->Q+i);
wprintf(L" = %f\n", r[i]);
}
fputws(L"---\n", stdout);
}
r += n_M;

for (size_t i = 1; i < m_neural; ++i) {
if (!aseo_reuse(P, n_samples, MAXENT_SEMANTICS, NULL, (int) scale, i, W, U, &M,
approx_rec_query_maxent)) {
approx_rec_query_maxent, status)) {
models_free_contents(&M);
goto cleanup;
}
Expand All @@ -57,6 +65,13 @@ static bool _aseo_maxent(program_t *P, size_t n_samples, size_t scale, double **
approx_query_maxent_ab(P, &M, a, b);
approx_query_maxent_r(P, &M, r, a, b);
models_free_contents(&M);
if (!quiet) {
for (size_t i = 0; i < P->Q_n; ++i) {
print_query(P->Q+i);
wprintf(L" = %f\n", r[i]);
}
fputws(L"---\n", stdout);
}
r += n_M;
}

Expand Down Expand Up @@ -88,7 +103,7 @@ static PyObject* _aseo(PyObject *self, PyObject *args, PyObject *kwargs) {
if (P.stable) if (!ground_all(P.stable, NULL)) goto cleanup;
}

if (!_aseo_maxent(&P, n_samples, scale, &R)) goto cleanup;
if (!_aseo_maxent(&P, n_samples, scale, &R, quiet, status)) goto cleanup;

bool has_neural = P.NA_n+P.NR_n > 0;
npy_intp dims[3] = {P.Q_n, 1, 1};
Expand Down
10 changes: 7 additions & 3 deletions pasp/caseo.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <math.h>

#include "cutils.h"
#include "../progressbar/statusbar.h"

bool watch_minimize(clingo_weight_t p, const clingo_weighted_literal_t *W, size_t n, void *data) {
void **pack = (void**) data;
Expand Down Expand Up @@ -228,10 +229,11 @@ bool set_upper_bound(clingo_backend_t *back, clingo_weighted_literal_t *W, size_
bool aseo_reuse(program_t *P, size_t k, psemantics_t psem, observations_t *O, int scale,
size_t neural_idx, clingo_weighted_literal_t *W, clingo_weighted_literal_t *U, models_t* M,
bool (*f)(const clingo_model_t*, program_t*, models_t*, size_t, observations_t*,
clingo_control_t*)) {
clingo_control_t*), bool status) {
bool ok = false;
clingo_control_t *C = NULL;
clingo_backend_t *back = NULL;
statusbar *bar = status ? statusbar_new("Querying") : NULL;

size_t n = num_prob_params(P);
if (!(W && U)) {
Expand Down Expand Up @@ -291,6 +293,7 @@ bool aseo_reuse(program_t *P, size_t k, psemantics_t psem, observations_t *O, in

if (!set_upper_bound(back, W, n, U, (int) cost)) goto cleanup;
if (!aseo_solve(P, C, k, &res, &m, M, O, neural_idx, f)) goto cleanup;
if (bar) statusbar_inc(bar);
}
M->m = m;

Expand All @@ -309,19 +312,20 @@ bool aseo_reuse(program_t *P, size_t k, psemantics_t psem, observations_t *O, in
if (clingo_error_code() != clingo_error_success) raise_clingo_error(NULL);
else if (!ok) PyErr_SetString(PyExc_RuntimeError, "an error has occurred during ASEO!");
clingo_control_free(C);
if (bar) statusbar_finish(bar);
if (!ok) { models_free(M); M = NULL; }
return ok;
}

bool aseo(program_t *P, size_t k, psemantics_t psem, observations_t *O, int scale,
size_t neural_idx, models_t *M, bool (*f)(const clingo_model_t*, program_t*, models_t*, size_t,
observations_t*, clingo_control_t*)) {
observations_t*, clingo_control_t*), bool status) {
bool ok = false;
size_t n = num_prob_params(P);
clingo_weighted_literal_t *W = (clingo_weighted_literal_t*) malloc(n*sizeof(clingo_weighted_literal_t));
clingo_weighted_literal_t *U = (clingo_weighted_literal_t*) malloc(n*sizeof(clingo_weighted_literal_t));
if (!(W && U)) goto nomem;
if (!aseo_reuse(P, k, psem, O, scale, neural_idx, W, U, M, f)) goto cleanup;
if (!aseo_reuse(P, k, psem, O, scale, neural_idx, W, U, M, f, status)) goto cleanup;

ok = true;
goto cleanup;
Expand Down
4 changes: 2 additions & 2 deletions pasp/caseo.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

bool aseo(program_t *P, size_t k, psemantics_t psem, observations_t *O, int scale,
size_t neural_idx, models_t* M, bool (*f)(const clingo_model_t*, program_t*, models_t*, size_t,
observations_t*, clingo_control_t*));
observations_t*, clingo_control_t*), bool status);

bool aseo_reuse(program_t *P, size_t k, psemantics_t psem, observations_t *O, int scale,
size_t neural_idx, clingo_weighted_literal_t *W, clingo_weighted_literal_t *U, models_t *M,
bool (*f)(const clingo_model_t*, program_t*, models_t*, size_t,
observations_t*, clingo_control_t*));
observations_t*, clingo_control_t*), bool status);

#endif
7 changes: 6 additions & 1 deletion pasp/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ SEMANTICS_OPT_PROB: "maxent" | "credal"
_semantics_exp: ((SEMANTICS_OPT_LOGIC ("," SEMANTICS_OPT_PROB)?) | (SEMANTICS_OPT_PROB ("," SEMANTICS_OPT_LOGIC)?))
semantics: "#semantics" (("(" _semantics_exp ")") | (_semantics_exp)) "."

// Inference directive.
exact_inf: "exact"
aseo_inf: "aseo" "," "nmodels" "=" ID
inference: "#inference" (exact_inf | aseo_inf) "."

// Constraint.
constraint: ":-" body "."

Expand All @@ -173,7 +178,7 @@ query: "#query" (("(" _interp_exp ")") | ( _interp_exp )) "."?
// Constant definition.
constdef: "#const" WORD "=" ID "."

plp: (constdef | _fact | _rule | _ad | _neural | data | python | constraint | query | learn | semantics | _aggr)*
plp: (constdef | _fact | _rule | _ad | _neural | data | python | constraint | query | learn | semantics | _aggr | inference)*

COMMENT: "%" /[^\n]*/ NEWLINE

Expand Down
6 changes: 5 additions & 1 deletion pasp/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ def learn(self, L):
data = self.torch_scope[L[0][1]] if L[0][0] == "PY_FUNC" else StableTransformer.path2obs(L[0][1])
return self.pack("directive", "", ("learn", data, A))

def exact_inf(self, I): return ("inference", "exact", tuple())
def aseo_inf(self, I): return ("inference", "aseo", (I[0][2],))
def inference(self, I): return self.pack("directive", "", I[0])

# Semantics directive and options.
def SEMANTICS_OPT_LOGIC(self, _): return lark.visitors.Discard
def SEMANTICS_OPT_PROB(self, O): return str(O)
Expand Down Expand Up @@ -503,7 +507,7 @@ def plp(self, C) -> Program:
# Actual neural rules and neural ADs.
NR, NA = [], []
# Directives.
directives = {}
directives = {"inference": ("exact", tuple())}
# Mapping.
M = {"pfact": PF, "prule": PR, "query": Q, "varquery": VQ, "cfact": CF, "ad": AD, "nrule": TNR,
"nad": TNA}
Expand Down
8 changes: 6 additions & 2 deletions pasp/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,11 @@ def __call__(self, **kwargs):
else: learn(self, D, **A)
if len(self.Q) + len(self.VQ) > 0:
from exact import exact
from approx import aseo
A = {"quiet": False, "status": True}
A.update(kwargs)
if "psemantics" in self.directives: A.update(self.directives["psemantics"])
return exact(self, **A)
# TODO: implement additional semantics for ASEO and remove the exact exception below.
if ("psemantics" in self.directives) and (self.directives["inference"][0] == "exact"):
A.update(self.directives["psemantics"])
f = vars()[self.directives["inference"][0]]
return f(self, *self.directives["inference"][1], **A)

0 comments on commit 0016235

Please sign in to comment.