diff --git a/api/handle_scenarios.go b/api/handle_scenarios.go index 9f70394d4..a72c0bb41 100644 --- a/api/handle_scenarios.go +++ b/api/handle_scenarios.go @@ -1,11 +1,15 @@ package api import ( + "io" "net/http" + "github.com/cockroachdb/errors" "github.com/gin-gonic/gin" "github.com/checkmarble/marble-backend/dto" + "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/models/ast" "github.com/checkmarble/marble-backend/pure_utils" "github.com/checkmarble/marble-backend/usecases" "github.com/checkmarble/marble-backend/utils" @@ -24,7 +28,13 @@ func listScenarios(uc usecases.Usecases) func(c *gin.Context) { if presentError(ctx, c, err) { return } - c.JSON(http.StatusOK, pure_utils.Map(scenarios, dto.AdaptScenarioDto)) + + scenariosDto, err := pure_utils.MapErr(scenarios, dto.AdaptScenarioDto) + if presentError(ctx, c, err) { + return + } + + c.JSON(http.StatusOK, scenariosDto) } } @@ -49,7 +59,13 @@ func createScenario(uc usecases.Usecases) func(c *gin.Context) { if presentError(ctx, c, err) { return } - c.JSON(http.StatusOK, dto.AdaptScenarioDto(scenario)) + + scenarioDto, err := dto.AdaptScenarioDto(scenario) + if presentError(ctx, c, err) { + return + } + + c.JSON(http.StatusOK, scenarioDto) } } @@ -64,7 +80,13 @@ func getScenario(uc usecases.Usecases) func(c *gin.Context) { if presentError(ctx, c, err) { return } - c.JSON(http.StatusOK, dto.AdaptScenarioDto(scenario)) + + scenarioDto, err := dto.AdaptScenarioDto(scenario) + if presentError(ctx, c, err) { + return + } + + c.JSON(http.StatusOK, scenarioDto) } } @@ -86,6 +108,53 @@ func updateScenario(uc usecases.Usecases) func(c *gin.Context) { if presentError(ctx, c, err) { return } - c.JSON(http.StatusOK, dto.AdaptScenarioDto(scenario)) + + scenarioDto, err := dto.AdaptScenarioDto(scenario) + if presentError(ctx, c, err) { + return + } + + c.JSON(http.StatusOK, scenarioDto) + } +} + +type PostScenarioAstValidationInputBody struct { + Node dto.NodeDto `json:"node" binding:"required"` + ExpectedReturnType string `json:"expected_return_type"` +} + +func validateScenarioAst(uc usecases.Usecases) func(c *gin.Context) { + return func(c *gin.Context) { + ctx := c.Request.Context() + var input PostScenarioAstValidationInputBody + err := c.ShouldBindJSON(&input) + if err != nil && err != io.EOF { //nolint:errorlint + c.Status(http.StatusBadRequest) + return + } + + scenarioId := c.Param("scenario_id") + + astNode, err := dto.AdaptASTNode(input.Node) + if err != nil { + presentError(ctx, c, errors.Wrap(models.BadParameterError, err.Error())) + return + } + + expectedReturnType := "bool" + if input.ExpectedReturnType != "" { + expectedReturnType = input.ExpectedReturnType + } + + usecase := usecasesWithCreds(ctx, uc).NewScenarioUsecase() + astValidation, err := usecase.ValidateScenarioAst(ctx, scenarioId, &astNode, expectedReturnType) + + if presentError(ctx, c, err) { + return + } + + c.JSON(http.StatusOK, gin.H{ + "ast_validation": ast.AdaptNodeEvaluationDto(astValidation), + }) } } diff --git a/api/routes.go b/api/routes.go index 4e4946581..d648146a1 100644 --- a/api/routes.go +++ b/api/routes.go @@ -57,6 +57,7 @@ func addRoutes(r *gin.Engine, conf Configuration, uc usecases.Usecases, auth Aut router.POST("/scenarios", tom, createScenario(uc)) router.GET("/scenarios/:scenario_id", tom, getScenario(uc)) router.PATCH("/scenarios/:scenario_id", tom, updateScenario(uc)) + router.POST("/scenarios/:scenario_id/validate-ast", tom, validateScenarioAst(uc)) router.GET("/scenario-iterations", tom, handleListScenarioIterations(uc)) router.POST("/scenario-iterations", tom, handleCreateScenarioIteration(uc)) diff --git a/dto/scenarios.go b/dto/scenarios.go index 4490f79e6..a1fb96366 100644 --- a/dto/scenarios.go +++ b/dto/scenarios.go @@ -1,6 +1,7 @@ package dto import ( + "fmt" "time" "github.com/checkmarble/marble-backend/models" @@ -15,6 +16,7 @@ type ScenarioDto struct { DecisionToCaseOutcomes []string `json:"decision_to_case_outcomes"` DecisionToCaseInboxId null.String `json:"decision_to_case_inbox_id"` DecisionToCaseWorkflowType string `json:"decision_to_case_workflow_type"` + DecisionToCaseNameTemplate *NodeDto `json:"decision_to_case_name_template"` Description string `json:"description"` LiveVersionID *string `json:"live_version_id,omitempty"` Name string `json:"name"` @@ -22,8 +24,8 @@ type ScenarioDto struct { TriggerObjectType string `json:"trigger_object_type"` } -func AdaptScenarioDto(scenario models.Scenario) ScenarioDto { - return ScenarioDto{ +func AdaptScenarioDto(scenario models.Scenario) (ScenarioDto, error) { + scenarioDto := ScenarioDto{ Id: scenario.Id, CreatedAt: scenario.CreatedAt, DecisionToCaseInboxId: null.StringFromPtr(scenario.DecisionToCaseInboxId), @@ -36,6 +38,17 @@ func AdaptScenarioDto(scenario models.Scenario) ScenarioDto { OrganizationId: scenario.OrganizationId, TriggerObjectType: scenario.TriggerObjectType, } + + if scenario.DecisionToCaseNameTemplate != nil { + astDto, err := AdaptNodeDto(*scenario.DecisionToCaseNameTemplate) + if err != nil { + return ScenarioDto{}, + fmt.Errorf("unable to marshal ast expression: %w", err) + } + scenarioDto.DecisionToCaseNameTemplate = &astDto + } + + return scenarioDto, nil } // Create scenario DTO @@ -61,6 +74,7 @@ type UpdateScenarioBody struct { DecisionToCaseOutcomes []string `json:"decision_to_case_outcomes"` DecisionToCaseInboxId null.String `json:"decision_to_case_inbox_id"` DecisionToCaseWorkflowType *string `json:"decision_to_case_workflow_type"` + DecisionToCaseNameTemplate *NodeDto `json:"decision_to_case_name_template"` Description *string `json:"description"` Name *string `json:"name"` } @@ -79,5 +93,11 @@ func AdaptUpdateScenarioInput(scenarioId string, input UpdateScenarioBody) model val := models.WorkflowType(*input.DecisionToCaseWorkflowType) parsedInput.DecisionToCaseWorkflowType = &val } + if input.DecisionToCaseNameTemplate != nil { + astNode, err := AdaptASTNode(*input.DecisionToCaseNameTemplate) + if err == nil { + parsedInput.DecisionToCaseNameTemplate = &astNode + } + } return parsedInput } diff --git a/models/ast/ast_function.go b/models/ast/ast_function.go index b9d326b3d..77f0cc500 100644 --- a/models/ast/ast_function.go +++ b/models/ast/ast_function.go @@ -69,6 +69,7 @@ const ( FUNC_STRING_STARTS_WITH FUNC_STRING_ENDS_WITH FUNC_IS_MULTIPLE_OF + FUNC_STRING_TEMPLATE FUNC_UNDEFINED Function = -1 FUNC_UNKNOWN Function = -2 ) @@ -229,6 +230,10 @@ var FuncAttributesMap = map[Function]FuncAttributes{ AstName: "IsMultipleOf", NamedArguments: []string{"value", "divider"}, }, + FUNC_STRING_TEMPLATE: { + DebugName: "FUNC_STRING_TEMPLATE", + AstName: "StringTemplate", + }, FUNC_FILTER: FuncFilterAttributes, } diff --git a/models/ast/ast_node_evaluation.go b/models/ast/ast_node_evaluation.go index 04306bb36..c3c56fef8 100644 --- a/models/ast/ast_node_evaluation.go +++ b/models/ast/ast_node_evaluation.go @@ -41,5 +41,17 @@ func (root NodeEvaluation) GetBoolReturnValue() (bool, error) { } return false, errors.New( - fmt.Sprintf("root ast expression does not return a boolean, '%v' instead", root.ReturnValue)) + fmt.Sprintf("root ast expression does not return a boolean, '%T' instead", root.ReturnValue)) +} + +func (root NodeEvaluation) GetStringReturnValue() (string, error) { + if root.ReturnValue == nil { + return "", ErrNullFieldRead + } + + if returnValue, ok := root.ReturnValue.(string); ok { + return returnValue, nil + } + + return "", errors.New(fmt.Sprintf("ast expression expected to return a string, got '%T' instead", root.ReturnValue)) } diff --git a/models/scenarios.go b/models/scenarios.go index 7ef8fdbb2..cca0a16d6 100644 --- a/models/scenarios.go +++ b/models/scenarios.go @@ -3,6 +3,7 @@ package models import ( "time" + "github.com/checkmarble/marble-backend/models/ast" "github.com/guregu/null/v5" ) @@ -26,6 +27,7 @@ type Scenario struct { DecisionToCaseOutcomes []Outcome DecisionToCaseInboxId *string DecisionToCaseWorkflowType WorkflowType + DecisionToCaseNameTemplate *ast.Node Description string LiveVersionID *string Name string @@ -45,6 +47,7 @@ type UpdateScenarioInput struct { DecisionToCaseOutcomes []Outcome DecisionToCaseInboxId null.String DecisionToCaseWorkflowType *WorkflowType + DecisionToCaseNameTemplate *ast.Node Description *string Name *string } diff --git a/repositories/dbmodels/db_scenario.go b/repositories/dbmodels/db_scenario.go index 6d73b6380..48dd0729f 100644 --- a/repositories/dbmodels/db_scenario.go +++ b/repositories/dbmodels/db_scenario.go @@ -1,6 +1,7 @@ package dbmodels import ( + "fmt" "time" "github.com/checkmarble/marble-backend/models" @@ -16,6 +17,7 @@ type DBScenario struct { DecisionToCaseInboxId pgtype.Text `db:"decision_to_case_inbox_id"` DecisionToCaseOutcomes []string `db:"decision_to_case_outcomes"` DecisionToCaseWorkflowType string `db:"decision_to_case_workflow_type"` + DecisionToCaseNameTemplate []byte `db:"decision_to_case_name_template"` DeletedAt pgtype.Time `db:"deleted_at"` Description string `db:"description"` LiveVersionID pgtype.Text `db:"live_scenario_iteration_id"` @@ -46,5 +48,13 @@ func AdaptScenario(dto DBScenario) (models.Scenario, error) { if dto.LiveVersionID.Valid { scenario.LiveVersionID = &dto.LiveVersionID.String } + + var err error + scenario.DecisionToCaseNameTemplate, err = + AdaptSerializedAstExpression(dto.DecisionToCaseNameTemplate) + if err != nil { + return scenario, fmt.Errorf("unable to unmarshal ast expression: %w", err) + } + return scenario, nil } diff --git a/repositories/migrations/20250102151657_case_name_template.sql b/repositories/migrations/20250102151657_case_name_template.sql new file mode 100644 index 000000000..f1fc46538 --- /dev/null +++ b/repositories/migrations/20250102151657_case_name_template.sql @@ -0,0 +1,11 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE scenarios +ADD COLUMN decision_to_case_name_template JSON; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE scenarios +DROP COLUMN decision_to_case_name_template; +-- +goose StatementEnd diff --git a/repositories/scenarios_write.go b/repositories/scenarios_write.go index dba42f1b0..a14f56df8 100644 --- a/repositories/scenarios_write.go +++ b/repositories/scenarios_write.go @@ -2,6 +2,7 @@ package repositories import ( "context" + "fmt" "github.com/checkmarble/marble-backend/models" "github.com/checkmarble/marble-backend/repositories/dbmodels" @@ -66,6 +67,15 @@ func (repo *MarbleDbRepository) UpdateScenario(ctx context.Context, exec Executo sql = sql.Set("decision_to_case_workflow_type", scenario.DecisionToCaseWorkflowType) countApply++ } + if scenario.DecisionToCaseNameTemplate != nil { + serializedAst, err := dbmodels.SerializeFormulaAstExpression(scenario.DecisionToCaseNameTemplate) + if err != nil { + return fmt.Errorf( + "unable to marshal ast expression: %w", err) + } + sql = sql.Set("decision_to_case_name_template", serializedAst) + countApply++ + } if scenario.Description != nil { sql = sql.Set("description", scenario.Description) countApply++ diff --git a/usecases/ast_eval/evaluate/eval_string_template.go b/usecases/ast_eval/evaluate/eval_string_template.go new file mode 100644 index 000000000..234362c5c --- /dev/null +++ b/usecases/ast_eval/evaluate/eval_string_template.go @@ -0,0 +1,71 @@ +package evaluate + +import ( + "context" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/checkmarble/marble-backend/models/ast" + cockroachdbErrors "github.com/cockroachdb/errors" +) + +var stringTemplateVariableRegexp = regexp.MustCompile(`(?mi)%([a-z0-9_]+)%`) + +type StringTemplate struct{} + +func (f StringTemplate) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) { + if err := verifyNumberOfArguments(arguments.Args, 1); err != nil { + return MakeEvaluateError(err) + } + + if arguments.Args[0] == nil || arguments.Args[0] == "" { + return nil, MakeAdaptedArgsErrors([]error{ast.ErrArgumentRequired}) + } + + template, templateErr := adaptArgumentToString(arguments.Args[0]) + if templateErr != nil { + return MakeEvaluateError(templateErr) + } + + var execErrors []error + replacedTemplate := template + for _, match := range stringTemplateVariableRegexp.FindAllStringSubmatch(template, -1) { + variableValue, argErr := adapatVariableValue(arguments.NamedArgs, match[1]) + if argErr != nil { + if !errors.Is(argErr, ast.ErrArgumentRequired) { + execErrors = append(execErrors, argErr) + continue + } + variableValue = "{}" + } + replacedTemplate = strings.Replace(replacedTemplate, + fmt.Sprintf("%%%s%%", match[1]), variableValue, -1) + } + + errs := MakeAdaptedArgsErrors(execErrors) + if len(errs) > 0 { + return nil, errs + } + + return replacedTemplate, nil +} + +func adapatVariableValue(namedArgs map[string]any, name string) (string, error) { + if value, err := AdaptNamedArgument(namedArgs, name, adaptArgumentToString); err == nil { + return value, nil + } + + if value, err := AdaptNamedArgument(namedArgs, name, promoteArgumentToFloat64); err == nil { + return strconv.FormatFloat(value, 'f', 2, 64), nil + } + + if value, err := AdaptNamedArgument(namedArgs, name, promoteArgumentToInt64); err == nil { + return strconv.FormatInt(value, 10), nil + } + + return "", cockroachdbErrors.Wrap(ast.ErrArgumentInvalidType, + "all variables to String Template Evaluate must be string, int or float") +} diff --git a/usecases/ast_eval/evaluate_ast_expression.go b/usecases/ast_eval/evaluate_ast_expression.go index bfc54a458..a2c176805 100644 --- a/usecases/ast_eval/evaluate_ast_expression.go +++ b/usecases/ast_eval/evaluate_ast_expression.go @@ -19,7 +19,7 @@ func (evaluator *EvaluateAstExpression) EvaluateAstExpression( organizationId string, payload models.ClientObject, dataModel models.DataModel, -) (bool, ast.NodeEvaluation, error) { +) (ast.NodeEvaluation, error) { environment := evaluator.AstEvaluationEnvironmentFactory(EvaluationEnvironmentFactoryParams{ OrganizationId: organizationId, ClientObject: payload, @@ -29,12 +29,8 @@ func (evaluator *EvaluateAstExpression) EvaluateAstExpression( evaluation, ok := EvaluateAst(ctx, environment, ruleAstExpression) if !ok { - return false, evaluation, errors.Join(evaluation.FlattenErrors()...) + return evaluation, errors.Join(evaluation.FlattenErrors()...) } - returnValue, err := evaluation.GetBoolReturnValue() - if err != nil { - return false, evaluation, errors.Join(ast.ErrRuntimeExpression, err) - } - return returnValue, evaluation, nil + return evaluation, nil } diff --git a/usecases/ast_eval/evaluate_environment.go b/usecases/ast_eval/evaluate_environment.go index 82c1384a9..28a6e13f9 100644 --- a/usecases/ast_eval/evaluate_environment.go +++ b/usecases/ast_eval/evaluate_environment.go @@ -74,5 +74,6 @@ func NewAstEvaluationEnvironment() AstEvaluationEnvironment { environment.AddEvaluator(ast.FUNC_IS_EMPTY, evaluate.IsEmpty{}) environment.AddEvaluator(ast.FUNC_IS_NOT_EMPTY, evaluate.IsNotEmpty{}) environment.AddEvaluator(ast.FUNC_IS_MULTIPLE_OF, evaluate.IsMultipleOf{}) + environment.AddEvaluator(ast.FUNC_STRING_TEMPLATE, evaluate.StringTemplate{}) return environment } diff --git a/usecases/decision_usecase.go b/usecases/decision_usecase.go index 7521b18af..75df70822 100644 --- a/usecases/decision_usecase.go +++ b/usecases/decision_usecase.go @@ -82,6 +82,8 @@ type decisionWorkflowsUsecase interface { tx repositories.Transaction, scenario models.Scenario, decision models.DecisionWithRuleExecutions, + repositories evaluate_scenario.ScenarioEvaluationRepositories, + params evaluate_scenario.ScenarioEvaluationParameters, webhookEventId string, ) (bool, error) } @@ -459,6 +461,8 @@ func (usecase *DecisionUsecase) CreateDecision( tx, scenario, decision, + evaluationRepositories, + evaluationParameters, caseWebhookEventId) if err != nil { return models.DecisionWithRuleExecutions{}, err @@ -622,8 +626,16 @@ func (usecase *DecisionUsecase) CreateAllDecisions( sendWebhookEventIds = append(sendWebhookEventIds, webhookEventId) caseWebhookEventId := uuid.NewString() + + evaluationParameters := evaluate_scenario.ScenarioEvaluationParameters{ + Scenario: item.scenario, + ClientObject: payload, + DataModel: dataModel, + Pivot: pivot, + } webhookEventCreated, err := usecase.decisionWorkflows.AutomaticDecisionToCase( - ctx, tx, item.scenario, item.decision, caseWebhookEventId) + ctx, tx, item.scenario, item.decision, evaluationRepositories, + evaluationParameters, caseWebhookEventId) if err != nil { return nil, err } diff --git a/usecases/decision_workflows/decision_workflows.go b/usecases/decision_workflows/decision_workflows.go index f2d9e4db9..9afd08031 100644 --- a/usecases/decision_workflows/decision_workflows.go +++ b/usecases/decision_workflows/decision_workflows.go @@ -7,6 +7,7 @@ import ( "github.com/checkmarble/marble-backend/models" "github.com/checkmarble/marble-backend/repositories" + "github.com/checkmarble/marble-backend/usecases/evaluate_scenario" "github.com/pkg/errors" ) @@ -71,6 +72,8 @@ func (d DecisionsWorkflows) AutomaticDecisionToCase( tx repositories.Transaction, scenario models.Scenario, decision models.DecisionWithRuleExecutions, + repositories evaluate_scenario.ScenarioEvaluationRepositories, + params evaluate_scenario.ScenarioEvaluationParameters, webhookEventId string, ) (addedToCase bool, err error) { if scenario.DecisionToCaseWorkflowType == models.WorkflowDisabled || @@ -81,7 +84,12 @@ func (d DecisionsWorkflows) AutomaticDecisionToCase( } if scenario.DecisionToCaseWorkflowType == models.WorkflowCreateCase { - input := automaticCreateCaseAttributes(scenario, decision) + caseName, err := evaluate_scenario.EvalCaseName(ctx, params, repositories, scenario, decision) + if err != nil { + return false, errors.Wrap(err, "error creating case for decision") + } + + input := automaticCreateCaseAttributes(scenario, decision, caseName) newCase, err := d.caseEditor.CreateCase(ctx, tx, "", input, false) if err != nil { return false, errors.Wrap(err, "error creating case for decision") @@ -105,7 +113,12 @@ func (d DecisionsWorkflows) AutomaticDecisionToCase( } if !added { - input := automaticCreateCaseAttributes(scenario, decision) + caseName, err := evaluate_scenario.EvalCaseName(ctx, params, repositories, scenario, decision) + if err != nil { + return false, errors.Wrap(err, "error creating case for decision") + } + + input := automaticCreateCaseAttributes(scenario, decision, caseName) newCase, err := d.caseEditor.CreateCase(ctx, tx, "", input, false) if err != nil { return false, errors.Wrap(err, "error creating case for decision") @@ -141,15 +154,12 @@ func (d DecisionsWorkflows) AutomaticDecisionToCase( func automaticCreateCaseAttributes( scenario models.Scenario, decision models.DecisionWithRuleExecutions, + name string, ) models.CreateCaseAttributes { return models.CreateCaseAttributes{ - DecisionIds: []string{decision.DecisionId}, - InboxId: *scenario.DecisionToCaseInboxId, - Name: fmt.Sprintf( - "Case for %s: %s", - scenario.TriggerObjectType, - decision.ClientObject.Data["object_id"], - ), + DecisionIds: []string{decision.DecisionId}, + InboxId: *scenario.DecisionToCaseInboxId, + Name: name, OrganizationId: scenario.OrganizationId, } } diff --git a/usecases/evaluate_scenario/evaluate_scenario.go b/usecases/evaluate_scenario/evaluate_scenario.go index a406abd2f..4a671040d 100644 --- a/usecases/evaluate_scenario/evaluate_scenario.go +++ b/usecases/evaluate_scenario/evaluate_scenario.go @@ -319,7 +319,7 @@ func evalScenarioRule( } // Evaluate single rule - returnValue, ruleEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( + ruleEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( ctx, *rule.FormulaAstExpression, dataAccessor.organizationId, @@ -327,11 +327,22 @@ func evalScenarioRule( dataModel, ) - if err != nil && !ast.IsAuthorizedError(err) { + isAuthorizedError := ast.IsAuthorizedError(err) + if err != nil && !isAuthorizedError { return 0, models.RuleExecution{}, errors.Wrap(err, fmt.Sprintf("error while evaluating rule %s (%s)", rule.Name, rule.Id)) } + var returnValue bool + if err == nil { + returnValue, err = ruleEvaluation.GetBoolReturnValue() + if err != nil && !ast.IsAuthorizedError(err) { + return 0, models.RuleExecution{}, errors.Wrap( + errors.Join(ast.ErrRuntimeExpression, err), + fmt.Sprintf("error while evaluating rule %s (%s)", rule.Name, rule.Id)) + } + } + ruleEvaluationDto := ast.AdaptNodeEvaluationDto(ruleEvaluation) ruleExecution := models.RuleExecution{ Outcome: "no_hit", @@ -374,20 +385,34 @@ func evalScenarioTrigger( ctx, span := tracer.Start(ctx, "evaluate_scenario.evalScenarioTrigger") defer span.End() - returnValue, _, err := repositories.EvaluateAstExpression.EvaluateAstExpression( + triggerEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( ctx, triggerAstExpression, organizationId, payload, dataModel, ) + isAuthorizedError := ast.IsAuthorizedError(err) + if err != nil && !isAuthorizedError { return errors.Wrap(err, "Unexpected error evaluating trigger condition in EvalScenario") } - if !returnValue || isAuthorizedError { + var returnValue bool + var isAuthorizedTypeError bool + if err == nil { + returnValue, err = triggerEvaluation.GetBoolReturnValue() + isAuthorizedTypeError = ast.IsAuthorizedError(err) + if err != nil && !isAuthorizedTypeError { + return errors.Wrap( + errors.Join(ast.ErrRuntimeExpression, err), + "Unexpected error evaluating trigger condition in EvalScenario") + } + } + + if !returnValue || isAuthorizedError || isAuthorizedTypeError { return errors.Wrap( models.ErrScenarioTriggerConditionAndTriggerObjectMismatch, "scenario trigger object does not match payload in EvalScenario") @@ -481,3 +506,39 @@ func getPivotValue(ctx context.Context, pivot models.Pivot, dataAccessor DataAcc return &valStr, nil } + +func EvalCaseName( + ctx context.Context, + params ScenarioEvaluationParameters, + repositories ScenarioEvaluationRepositories, + scenario models.Scenario, + decision models.DecisionWithRuleExecutions, +) (string, error) { + if scenario.DecisionToCaseNameTemplate == nil { + return fmt.Sprintf("Case for %s: %s", scenario.TriggerObjectType, + decision.ClientObject.Data["object_id"]), nil + } + + caseNameEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( + ctx, + *scenario.DecisionToCaseNameTemplate, + params.Scenario.OrganizationId, + params.ClientObject, + params.DataModel, + ) + + isAuthorizedError := ast.IsAuthorizedError(err) + if err != nil && !isAuthorizedError { + return "", errors.Wrap(err, + "Unexpected error evaluating case name in EvalCaseName") + } + + returnValue, err := caseNameEvaluation.GetStringReturnValue() + if err != nil && !isAuthorizedError { + return "", errors.Wrap( + errors.Join(ast.ErrRuntimeExpression, err), + "Unexpected error evaluating case name in EvalCaseName") + } + + return returnValue, nil +} diff --git a/usecases/scenario_usecase.go b/usecases/scenario_usecase.go index 058af97db..4a90848ff 100644 --- a/usecases/scenario_usecase.go +++ b/usecases/scenario_usecase.go @@ -5,9 +5,11 @@ import ( "slices" "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/models/ast" "github.com/checkmarble/marble-backend/pure_utils" "github.com/checkmarble/marble-backend/repositories" "github.com/checkmarble/marble-backend/usecases/executor_factory" + "github.com/checkmarble/marble-backend/usecases/scenarios" "github.com/checkmarble/marble-backend/usecases/security" "github.com/checkmarble/marble-backend/usecases/tracking" @@ -15,10 +17,12 @@ import ( ) type ScenarioUsecase struct { - transactionFactory executor_factory.TransactionFactory - executorFactory executor_factory.ExecutorFactory - enforceSecurity security.EnforceSecurityScenario - repository repositories.ScenarioUsecaseRepository + transactionFactory executor_factory.TransactionFactory + scenarioFetcher scenarios.ScenarioFetcher + validateScenarioAst scenarios.ValidateScenarioAst + executorFactory executor_factory.ExecutorFactory + enforceSecurity security.EnforceSecurityScenario + repository repositories.ScenarioUsecaseRepository } func (usecase *ScenarioUsecase) ListScenarios(ctx context.Context, organizationId string) ([]models.Scenario, error) { @@ -71,7 +75,8 @@ func (usecase *ScenarioUsecase) UpdateScenario( // influence how decisions are treated) so require a higher permission to update changeWorkflowSettings := scenarioInput.DecisionToCaseInboxId.Valid || scenarioInput.DecisionToCaseOutcomes != nil || - scenarioInput.DecisionToCaseWorkflowType != nil + scenarioInput.DecisionToCaseWorkflowType != nil || + scenarioInput.DecisionToCaseNameTemplate != nil if changeWorkflowSettings { if err := usecase.enforceSecurity.PublishScenario(scenario); err != nil { return models.Scenario{}, err @@ -82,6 +87,17 @@ func (usecase *ScenarioUsecase) UpdateScenario( return models.Scenario{}, err } + if scenarioInput.DecisionToCaseNameTemplate != nil { + validation, err := usecase.ValidateScenarioAst(ctx, scenarioInput.Id, + scenarioInput.DecisionToCaseNameTemplate, "string") + if err != nil { + return models.Scenario{}, err + } + if len(validation.FlattenErrors()) > 0 { + return models.Scenario{}, errors.Join(validation.FlattenErrors()...) + } + } + err = usecase.repository.UpdateScenario(ctx, tx, scenarioInput) if err != nil { return models.Scenario{}, err @@ -132,6 +148,24 @@ func validateScenarioUpdate(scenario models.Scenario, input models.UpdateScenari return nil } +func (usecase *ScenarioUsecase) ValidateScenarioAst(ctx context.Context, + scenarioId string, astNode *ast.Node, expectedReturnType string, +) (validation ast.NodeEvaluation, err error) { + scenario, err := usecase.scenarioFetcher.FetchScenario(ctx, + usecase.executorFactory.NewExecutor(), scenarioId) + if err != nil { + return validation, err + } + + if err := usecase.enforceSecurity.ReadScenario(scenario); err != nil { + return validation, err + } + + validation, err = usecase.validateScenarioAst.Validate(ctx, scenario, astNode, expectedReturnType) + + return validation, err +} + func (usecase *ScenarioUsecase) CreateScenario( ctx context.Context, scenario models.CreateScenarioInput, diff --git a/usecases/scenarios/scenario_and_iteration.go b/usecases/scenarios/scenario_and_iteration.go index 4eebd830f..d7ddea662 100644 --- a/usecases/scenarios/scenario_and_iteration.go +++ b/usecases/scenarios/scenario_and_iteration.go @@ -33,3 +33,7 @@ func (fetcher ScenarioFetcher) FetchScenarioAndIteration(ctx context.Context, return result, err } + +func (fetcher ScenarioFetcher) FetchScenario(ctx context.Context, exec repositories.Executor, scenarioId string) (models.Scenario, error) { + return fetcher.Repository.GetScenarioById(ctx, exec, scenarioId) +} diff --git a/usecases/scenarios/scenario_validation.go b/usecases/scenarios/scenario_validation.go index 22149a15d..ae94f7ef8 100644 --- a/usecases/scenarios/scenario_validation.go +++ b/usecases/scenarios/scenario_validation.go @@ -3,10 +3,13 @@ package scenarios import ( "context" "fmt" + "reflect" + "time" "github.com/cockroachdb/errors" "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/models/ast" "github.com/checkmarble/marble-backend/pure_utils" "github.com/checkmarble/marble-backend/repositories" "github.com/checkmarble/marble-backend/usecases/ast_eval" @@ -42,12 +45,12 @@ type ValidateScenarioIteration interface { } type ValidateScenarioIterationImpl struct { - DataModelRepository repositories.DataModelRepository - AstEvaluationEnvironmentFactory ast_eval.AstEvaluationEnvironmentFactory - ExecutorFactory executor_factory.ExecutorFactory + AstValidator AstValidator } -func (validator *ValidateScenarioIterationImpl) Validate(ctx context.Context, si models.ScenarioAndIteration) models.ScenarioValidation { +func (self *ValidateScenarioIterationImpl) Validate(ctx context.Context, + si models.ScenarioAndIteration, +) models.ScenarioValidation { iteration := si.Iteration result := models.NewScenarioValidation() @@ -71,7 +74,7 @@ func (validator *ValidateScenarioIterationImpl) Validate(ctx context.Context, si }) } - dryRunEnvironment, err := validator.makeDryRunEnvironment(ctx, si) + dryRunEnvironment, err := self.AstValidator.MakeDryRunEnvironment(ctx, si.Scenario) if err != nil { result.Errors = append(result.Errors, *err) return result @@ -121,16 +124,80 @@ func (validator *ValidateScenarioIterationImpl) Validate(ctx context.Context, si return result } +type ValidateScenarioAst interface { + Validate(ctx context.Context, scenario models.Scenario, astNode *ast.Node, + expectedReturnType string) (ast.NodeEvaluation, error) +} + +type ValidateScenarioAstImpl struct { + AstValidator AstValidator +} + +func (self *ValidateScenarioAstImpl) Validate(ctx context.Context, + scenario models.Scenario, + astNode *ast.Node, + expectedReturnTypeStr string, +) (ast.NodeEvaluation, error) { + dryRunEnvironment, err := self.AstValidator.MakeDryRunEnvironment(ctx, scenario) + if err != nil { + return ast.NodeEvaluation{}, err.Error + } + + expectedReturnType, ok := getTypeFromString(expectedReturnTypeStr) + if !ok { + return ast.NodeEvaluation{}, errors.Wrapf(models.BadParameterError, + "unknown specified type '%s'", expectedReturnTypeStr) + } + + astEvaluation, _ := ast_eval.EvaluateAst(ctx, dryRunEnvironment, *astNode) + astEvaluationReturnType := reflect.TypeOf(astEvaluation.ReturnValue) + + if len(astEvaluation.FlattenErrors()) == 0 && astEvaluationReturnType != expectedReturnType { + astEvaluation.Errors = append(astEvaluation.Errors, errors.Wrapf(models.BadParameterError, + "ast node does not return a %s", expectedReturnTypeStr)) + } + + return astEvaluation, nil +} + +func getTypeFromString(typeStr string) (reflect.Type, bool) { + switch typeStr { + case "string": + return reflect.TypeOf(""), true + case "int": + return reflect.TypeOf(int64(0)), true + case "float": + return reflect.TypeOf(float64(0)), true + case "bool": + return reflect.TypeOf(false), true + case "datetime": + return reflect.TypeOf(time.Now()), true + default: + return nil, false + } +} + func hasScoreThresholds(iteration models.ScenarioIteration) bool { return iteration.ScoreReviewThreshold != nil && iteration.ScoreBlockAndReviewThreshold != nil && iteration.ScoreDeclineThreshold != nil } -func (validator *ValidateScenarioIterationImpl) makeDryRunEnvironment(ctx context.Context, - si models.ScenarioAndIteration, +type AstValidator interface { + MakeDryRunEnvironment(ctx context.Context, scenario models.Scenario) ( + ast_eval.AstEvaluationEnvironment, *models.ScenarioValidationError) +} + +type AstValidatorImpl struct { + DataModelRepository repositories.DataModelRepository + AstEvaluationEnvironmentFactory ast_eval.AstEvaluationEnvironmentFactory + ExecutorFactory executor_factory.ExecutorFactory +} + +func (validator *AstValidatorImpl) MakeDryRunEnvironment(ctx context.Context, + scenario models.Scenario, ) (ast_eval.AstEvaluationEnvironment, *models.ScenarioValidationError) { - organizationId := si.Scenario.OrganizationId + organizationId := scenario.OrganizationId dataModel, err := validator.DataModelRepository.GetDataModel(ctx, validator.ExecutorFactory.NewExecutor(), organizationId, false) @@ -141,11 +208,11 @@ func (validator *ValidateScenarioIterationImpl) makeDryRunEnvironment(ctx contex } } - table, ok := dataModel.Tables[si.Scenario.TriggerObjectType] + table, ok := dataModel.Tables[scenario.TriggerObjectType] if !ok { return ast_eval.AstEvaluationEnvironment{}, &models.ScenarioValidationError{ Error: errors.Wrap(models.NotFoundError, - fmt.Sprintf("table %s not found in data model for dry run", si.Scenario.TriggerObjectType)), + fmt.Sprintf("table %s not found in data model for dry run", scenario.TriggerObjectType)), Code: models.TrigerObjectNotFound, } } diff --git a/usecases/scenarios/scenario_validation_test.go b/usecases/scenarios/scenario_validation_test.go index a0ee168e7..a43d3b1c2 100644 --- a/usecases/scenarios/scenario_validation_test.go +++ b/usecases/scenarios/scenario_validation_test.go @@ -88,7 +88,7 @@ func TestValidateScenarioIterationImpl_Validate(t *testing.T) { }, }, nil) - validator := ValidateScenarioIterationImpl{ + validator := AstValidatorImpl{ DataModelRepository: mdmr, AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment { return ast_eval.NewAstEvaluationEnvironment() @@ -96,7 +96,11 @@ func TestValidateScenarioIterationImpl_Validate(t *testing.T) { ExecutorFactory: executorFactory, } - result := validator.Validate(ctx, models.ScenarioAndIteration{ + siValidator := ValidateScenarioIterationImpl{ + AstValidator: &validator, + } + + result := siValidator.Validate(ctx, models.ScenarioAndIteration{ Scenario: scenario, Iteration: scenarioIteration, }) @@ -176,7 +180,7 @@ func TestValidateScenarioIterationImpl_Validate_notBool(t *testing.T) { }, }, nil) - validator := ValidateScenarioIterationImpl{ + validator := AstValidatorImpl{ DataModelRepository: mdmr, AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment { return ast_eval.NewAstEvaluationEnvironment() @@ -184,7 +188,11 @@ func TestValidateScenarioIterationImpl_Validate_notBool(t *testing.T) { ExecutorFactory: executorFactory, } - result := validator.Validate(ctx, models.ScenarioAndIteration{ + siValidator := ValidateScenarioIterationImpl{ + AstValidator: &validator, + } + + result := siValidator.Validate(ctx, models.ScenarioAndIteration{ Scenario: scenario, Iteration: scenarioIteration, }) diff --git a/usecases/scheduled_execution/async_decision_job.go b/usecases/scheduled_execution/async_decision_job.go index de648e5d1..dede2c24b 100644 --- a/usecases/scheduled_execution/async_decision_job.go +++ b/usecases/scheduled_execution/async_decision_job.go @@ -30,6 +30,8 @@ type decisionWorkflowsUsecase interface { tx repositories.Transaction, scenario models.Scenario, decision models.DecisionWithRuleExecutions, + repositories evaluate_scenario.ScenarioEvaluationRepositories, + params evaluate_scenario.ScenarioEvaluationParameters, webhookEventId string, ) (bool, error) } @@ -274,22 +276,27 @@ func (w *AsyncDecisionWorker) createSingleDecisionForObjectId( } object := models.ClientObject{TableName: table.Name, Data: objectMap[0]} + + evaluationParameters := evaluate_scenario.ScenarioEvaluationParameters{ + Scenario: scenario, + TargetIterationId: &args.ScenarioIterationId, + ClientObject: object, + DataModel: dataModel, + Pivot: pivot, + } + + evaluationRepositories := evaluate_scenario.ScenarioEvaluationRepositories{ + EvalScenarioRepository: w.repository, + ExecutorFactory: w.executorFactory, + IngestedDataReadRepository: w.ingestedDataReadRepository, + EvaluateAstExpression: w.evaluateAstExpression, + SnoozeReader: w.snoozesReader, + } + scenarioExecution, err := evaluate_scenario.EvalScenario( ctx, - evaluate_scenario.ScenarioEvaluationParameters{ - Scenario: scenario, - TargetIterationId: &args.ScenarioIterationId, - ClientObject: object, - DataModel: dataModel, - Pivot: pivot, - }, - evaluate_scenario.ScenarioEvaluationRepositories{ - EvalScenarioRepository: w.repository, - ExecutorFactory: w.executorFactory, - IngestedDataReadRepository: w.ingestedDataReadRepository, - EvaluateAstExpression: w.evaluateAstExpression, - SnoozeReader: w.snoozesReader, - }, + evaluationParameters, + evaluationRepositories, ) if errors.Is(err, models.ErrScenarioTriggerConditionAndTriggerObjectMismatch) { @@ -336,7 +343,8 @@ func (w *AsyncDecisionWorker) createSingleDecisionForObjectId( sendWebhookEventId = append(sendWebhookEventId, webhookEventId) caseWebhookEventId := uuid.NewString() - webhookEventCreated, err := w.decisionWorkflows.AutomaticDecisionToCase(ctx, tx, scenario, decision, caseWebhookEventId) + webhookEventCreated, err := w.decisionWorkflows.AutomaticDecisionToCase(ctx, tx, scenario, + decision, evaluationRepositories, evaluationParameters, caseWebhookEventId) if err != nil { return false, nil, err } diff --git a/usecases/usecases.go b/usecases/usecases.go index 3b198e951..39e52a188 100644 --- a/usecases/usecases.go +++ b/usecases/usecases.go @@ -215,8 +215,20 @@ func (usecases *Usecases) NewScenarioPublisher() ScenarioPublisher { } } +func (usecases *Usecases) NewValidateScenarioAst() scenarios.ValidateScenarioAst { + return &scenarios.ValidateScenarioAstImpl{ + AstValidator: usecases.NewAstValidator(), + } +} + func (usecases *Usecases) NewValidateScenarioIteration() scenarios.ValidateScenarioIteration { return &scenarios.ValidateScenarioIterationImpl{ + AstValidator: usecases.NewAstValidator(), + } +} + +func (usecases *Usecases) NewAstValidator() scenarios.AstValidator { + return &scenarios.AstValidatorImpl{ DataModelRepository: usecases.Repositories.DataModelRepository, AstEvaluationEnvironmentFactory: usecases.AstEvaluationEnvironmentFactory, ExecutorFactory: usecases.NewExecutorFactory(), diff --git a/usecases/usecases_with_creds.go b/usecases/usecases_with_creds.go index e253b8f8c..a23979a86 100644 --- a/usecases/usecases_with_creds.go +++ b/usecases/usecases_with_creds.go @@ -124,10 +124,12 @@ func (usecases *UsecasesWithCreds) NewDecisionWorkflows() decision_workflows.Dec func (usecases *UsecasesWithCreds) NewScenarioUsecase() ScenarioUsecase { return ScenarioUsecase{ - transactionFactory: usecases.NewTransactionFactory(), - executorFactory: usecases.NewExecutorFactory(), - enforceSecurity: usecases.NewEnforceScenarioSecurity(), - repository: &usecases.Repositories.MarbleDbRepository, + transactionFactory: usecases.NewTransactionFactory(), + scenarioFetcher: usecases.NewScenarioFetcher(), + validateScenarioAst: usecases.NewValidateScenarioAst(), + executorFactory: usecases.NewExecutorFactory(), + enforceSecurity: usecases.NewEnforceScenarioSecurity(), + repository: &usecases.Repositories.MarbleDbRepository, } }