From bd7d66ef59ae3deed695e98355d327461fdc8611 Mon Sep 17 00:00:00 2001 From: Earl Warren Date: Thu, 31 Oct 2024 15:58:39 +0100 Subject: [PATCH] fix: return an error when the argument count is wrong Closes forgejo/runner#307 --- pkg/exprparser/functions_test.go | 21 +++++++++++++++++++ pkg/exprparser/interpreter.go | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/pkg/exprparser/functions_test.go b/pkg/exprparser/functions_test.go index ea51a2bc8ba..c90b3268099 100644 --- a/pkg/exprparser/functions_test.go +++ b/pkg/exprparser/functions_test.go @@ -43,6 +43,9 @@ func TestFunctionContains(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("contains('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionStartsWith(t *testing.T) { @@ -72,6 +75,9 @@ func TestFunctionStartsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("startsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionEndsWith(t *testing.T) { @@ -101,6 +107,9 @@ func TestFunctionEndsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("endsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionJoin(t *testing.T) { @@ -128,6 +137,9 @@ func TestFunctionJoin(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("join()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionToJSON(t *testing.T) { @@ -154,6 +166,9 @@ func TestFunctionToJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("tojson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionFromJSON(t *testing.T) { @@ -177,6 +192,9 @@ func TestFunctionFromJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("fromjson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionHashFiles(t *testing.T) { @@ -248,4 +266,7 @@ func TestFunctionFormat(t *testing.T) { } }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("format()", DefaultStatusCheckNone) + assert.Error(t, err) } diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go index ce3aca38f96..a092e4064d4 100644 --- a/pkg/exprparser/interpreter.go +++ b/pkg/exprparser/interpreter.go @@ -589,23 +589,58 @@ func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallN args = append(args, reflect.ValueOf(value)) } + argCountCheck := func(argCount int) error { + if len(args) != argCount { + return fmt.Errorf("'%s' expected %d arguments but got %d instead", funcCallNode.Callee, argCount, len(args)) + } + return nil + } + + argAtLeastCheck := func(atLeast int) error { + if len(args) < atLeast { + return fmt.Errorf("'%s' expected at least %d arguments but got %d instead", funcCallNode.Callee, atLeast, len(args)) + } + return nil + } + switch strings.ToLower(funcCallNode.Callee) { case "contains": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.contains(args[0], args[1]) case "startswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.startsWith(args[0], args[1]) case "endswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.endsWith(args[0], args[1]) case "format": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } return impl.format(args[0], args[1:]...) case "join": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } if len(args) == 1 { return impl.join(args[0], reflect.ValueOf(",")) } return impl.join(args[0], args[1]) case "tojson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.toJSON(args[0]) case "fromjson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.fromJSON(args[0]) case "hashfiles": if impl.env.HashFiles != nil {