From f0c3f257f28e38c13dff520c26367ef83d495fb7 Mon Sep 17 00:00:00 2001 From: Marc Scholten Date: Tue, 7 Nov 2023 21:22:52 -0800 Subject: [PATCH] Added SECURITY DEFINER to sql --- IHP/IDE/SchemaDesigner/Compiler.hs | 6 +++++- IHP/IDE/SchemaDesigner/Parser.hs | 9 +++++++-- IHP/IDE/SchemaDesigner/SchemaOperations.hs | 1 + IHP/IDE/SchemaDesigner/Types.hs | 2 +- Test/IDE/CodeGeneration/MigrationGenerator.hs | 1 + Test/IDE/SchemaDesigner/CompilerSpec.hs | 17 +++++++++++++++++ Test/IDE/SchemaDesigner/ParserSpec.hs | 16 ++++++++++++++++ Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs | 2 ++ 8 files changed, 50 insertions(+), 4 deletions(-) diff --git a/IHP/IDE/SchemaDesigner/Compiler.hs b/IHP/IDE/SchemaDesigner/Compiler.hs index b05af3c59..46f797261 100644 --- a/IHP/IDE/SchemaDesigner/Compiler.hs +++ b/IHP/IDE/SchemaDesigner/Compiler.hs @@ -33,7 +33,7 @@ compileStatement RenameColumn { tableName, from, to } = "ALTER TABLE " <> compil compileStatement DropTable { tableName } = "DROP TABLE " <> compileIdentifier tableName <> ";" compileStatement Comment { content } = "--" <> content compileStatement CreateIndex { indexName, unique, tableName, columns, whereClause, indexType } = "CREATE" <> (if unique then " UNIQUE " else " ") <> "INDEX " <> compileIdentifier indexName <> " ON " <> compileIdentifier tableName <> (maybe "" (\indexType -> " USING " <> compileIndexType indexType) indexType) <> " (" <> (intercalate ", " (map compileIndexColumn columns)) <> ")" <> (case whereClause of Just expression -> " WHERE " <> compileExpression expression; Nothing -> "") <> ";" -compileStatement CreateFunction { functionName, functionArguments, functionBody, orReplace, returns, language } = "CREATE " <> (if orReplace then "OR REPLACE " else "") <> "FUNCTION " <> functionName <> "(" <> (functionArguments |> map (\(argName, argType) -> argName ++ " " ++ compilePostgresType argType) |> intercalate ", ") <> ")" <> " RETURNS " <> compilePostgresType returns <> " AS $$" <> functionBody <> "$$ language " <> language <> ";" +compileStatement CreateFunction { functionName, functionArguments, functionBody, orReplace, returns, language, securityDefiner } = "CREATE " <> (if orReplace then "OR REPLACE " else "") <> "FUNCTION " <> functionName <> "(" <> (functionArguments |> map (\(argName, argType) -> argName ++ " " ++ compilePostgresType argType) |> intercalate ", ") <> ")" <> " RETURNS " <> compilePostgresType returns <> compileSecurityDefiner securityDefiner <> " AS $$" <> functionBody <> "$$ language " <> language <> ";" compileStatement EnableRowLevelSecurity { tableName } = "ALTER TABLE " <> compileIdentifier tableName <> " ENABLE ROW LEVEL SECURITY;" compileStatement CreatePolicy { name, action, tableName, using, check } = "CREATE POLICY " <> compileIdentifier name <> " ON " <> compileIdentifier tableName <> maybe "" (\action -> " FOR " <> compilePolicyAction action) action <> maybe "" (\expr -> " USING (" <> compileExpression expr <> ")") using <> maybe "" (\expr -> " WITH CHECK (" <> compileExpression expr <> ")") check <> ";" compileStatement CreateSequence { name } = "CREATE SEQUENCE " <> compileIdentifier name <> ";" @@ -500,3 +500,7 @@ compileIndexColumnOrder Asc = "ASC" compileIndexColumnOrder Desc = "DESC" compileIndexColumnOrder NullsFirst = "NULLS FIRST" compileIndexColumnOrder NullsLast = "NULLS LAST" + +compileSecurityDefiner :: Bool -> Text +compileSecurityDefiner True = " SECURITY DEFINER" +compileSecurityDefiner False = "" \ No newline at end of file diff --git a/IHP/IDE/SchemaDesigner/Parser.hs b/IHP/IDE/SchemaDesigner/Parser.hs index 753e0c21c..e337b0fcd 100644 --- a/IHP/IDE/SchemaDesigner/Parser.hs +++ b/IHP/IDE/SchemaDesigner/Parser.hs @@ -604,8 +604,13 @@ createFunction = do lexeme "language" <|> lexeme "LANGUAGE" symbol' "plpgsql" <|> symbol' "SQL" + securityDefiner <- isJust <$> optional do + lexeme "SECURITY" + lexeme "DEFINER" + pure True + lexeme "AS" - space + -- space functionBody <- cs <$> between (char '$' >> char '$') (char '$' >> char '$') (many (anySingleBut '$')) space @@ -615,7 +620,7 @@ createFunction = do lexeme "language" <|> lexeme "LANGUAGE" symbol' "plpgsql" <|> symbol' "SQL" char ';' - pure CreateFunction { functionName, functionArguments, functionBody, orReplace, returns, language } + pure CreateFunction { functionName, functionArguments, functionBody, orReplace, returns, language, securityDefiner } where functionArgument = do argumentName <- qualifiedIdentifier diff --git a/IHP/IDE/SchemaDesigner/SchemaOperations.hs b/IHP/IDE/SchemaDesigner/SchemaOperations.hs index c6ca53286..f24f79a70 100644 --- a/IHP/IDE/SchemaDesigner/SchemaOperations.hs +++ b/IHP/IDE/SchemaDesigner/SchemaOperations.hs @@ -525,6 +525,7 @@ addUpdatedAtTrigger tableName schema = , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } deleteTriggerIfExists :: Text -> [Statement] -> [Statement] diff --git a/IHP/IDE/SchemaDesigner/Types.hs b/IHP/IDE/SchemaDesigner/Types.hs index 1d9fbf465..017cb9a85 100644 --- a/IHP/IDE/SchemaDesigner/Types.hs +++ b/IHP/IDE/SchemaDesigner/Types.hs @@ -35,7 +35,7 @@ data Statement -- | DROP INDEX indexName; | DropIndex { indexName :: Text } -- | CREATE OR REPLACE FUNCTION functionName(param1 TEXT, param2 INT) RETURNS TRIGGER AS $$functionBody$$ language plpgsql; - | CreateFunction { functionName :: Text, functionArguments :: [(Text, PostgresType)], functionBody :: Text, orReplace :: Bool, returns :: PostgresType, language :: Text } + | CreateFunction { functionName :: Text, functionArguments :: [(Text, PostgresType)], functionBody :: Text, orReplace :: Bool, returns :: PostgresType, language :: Text, securityDefiner :: Bool } -- | ALTER TABLE tableName ENABLE ROW LEVEL SECURITY; | EnableRowLevelSecurity { tableName :: Text } -- CREATE POLICY name ON tableName USING using WITH CHECK check; diff --git a/Test/IDE/CodeGeneration/MigrationGenerator.hs b/Test/IDE/CodeGeneration/MigrationGenerator.hs index 16c64739b..1136870b1 100644 --- a/Test/IDE/CodeGeneration/MigrationGenerator.hs +++ b/Test/IDE/CodeGeneration/MigrationGenerator.hs @@ -1314,6 +1314,7 @@ CREATE POLICY "Users can read and edit their own record" ON public.users USING ( , orReplace = False , returns = PTrigger , language = "PLPGSQL" + , securityDefiner = False }] it "should delete the updated_at trigger when the updated_at column is deleted" do diff --git a/Test/IDE/SchemaDesigner/CompilerSpec.hs b/Test/IDE/SchemaDesigner/CompilerSpec.hs index 22a2f3bc4..996d4f4df 100644 --- a/Test/IDE/SchemaDesigner/CompilerSpec.hs +++ b/Test/IDE/SchemaDesigner/CompilerSpec.hs @@ -749,6 +749,7 @@ tests = do , orReplace = True , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } compileSql [statement] `shouldBe` sql @@ -762,6 +763,7 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } compileSql [statement] `shouldBe` sql @@ -775,6 +777,21 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False + } + + compileSql [statement] `shouldBe` sql + + it "should compile a CREATE FUNCTION with SECURITY DEFINER" do + let sql = cs [plain|CREATE FUNCTION create_membership_for_new_organisation() RETURNS TRIGGER SECURITY DEFINER AS $$ BEGIN INSERT INTO organisation_memberships (user_id, organisation_id) VALUES (ihp_user_id(), NEW.id); RETURN NEW; END; $$ language plpgsql;\n|] + let statement = CreateFunction + { functionName = "create_membership_for_new_organisation" + , functionArguments = [] + , functionBody = " BEGIN INSERT INTO organisation_memberships (user_id, organisation_id) VALUES (ihp_user_id(), NEW.id); RETURN NEW; END; " + , orReplace = False + , returns = PTrigger + , language = "plpgsql" + , securityDefiner = True } compileSql [statement] `shouldBe` sql diff --git a/Test/IDE/SchemaDesigner/ParserSpec.hs b/Test/IDE/SchemaDesigner/ParserSpec.hs index e7b558022..962efa9ce 100644 --- a/Test/IDE/SchemaDesigner/ParserSpec.hs +++ b/Test/IDE/SchemaDesigner/ParserSpec.hs @@ -762,6 +762,7 @@ tests = do , orReplace = True , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } it "should parse a CREATE FUNCTION ..() RETURNS TRIGGER .." do @@ -772,6 +773,7 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } it "should parse a CREATE FUNCTION with parameters ..() RETURNS TRIGGER .." do @@ -782,8 +784,21 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } + it "should parse a CREATE FUNCTION with SECURITY DEFINER" do + parseSql "CREATE FUNCTION create_membership_for_new_organisation() RETURNS TRIGGER LANGUAGE plpgsql SECURITY DEFINER AS $$ BEGIN INSERT INTO organisation_memberships (user_id, organisation_id) VALUES (ihp_user_id(), NEW.id); RETURN NEW; END; $$ language plpgsql;" `shouldBe` CreateFunction + { functionName = "create_membership_for_new_organisation" + , functionArguments = [] + , functionBody = " BEGIN INSERT INTO organisation_memberships (user_id, organisation_id) VALUES (ihp_user_id(), NEW.id); RETURN NEW; END; " + , orReplace = False + , returns = PTrigger + , language = "plpgsql" + , securityDefiner = True + } + + it "should parse CREATE FUNCTION statements that are outputted by pg_dump" do let sql = cs [plain| CREATE FUNCTION public.notify_did_change_projects() RETURNS trigger @@ -800,6 +815,7 @@ $$; , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } it "should parse a decimal default value with a type-cast" do diff --git a/Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs b/Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs index b4eff8b7b..cd5895949 100644 --- a/Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs +++ b/Test/IDE/SchemaDesigner/SchemaOperationsSpec.hs @@ -247,6 +247,7 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } let trigger = CreateTrigger { name = "update_a_updated_at" @@ -396,6 +397,7 @@ tests = do , orReplace = False , returns = PTrigger , language = "plpgsql" + , securityDefiner = False } let trigger = CreateTrigger { name = "update_a_updated_at"