From 16311f4ff54ceb7e329bc40f113f503d3ff8979a Mon Sep 17 00:00:00 2001 From: Laurence Isla Date: Mon, 19 Aug 2024 20:27:14 -0500 Subject: [PATCH] (WIP) fix: spread embeds failing with count() aggregate - Fixed "column reference is ambiguous" when selecting "?select=...table(col,count())" - Fixed "column . does not exist" when selecting "?select=...table(aias:count())" --- src/PostgREST/Plan.hs | 29 ++++++++++++++++++++--------- src/PostgREST/Plan/Types.hs | 3 ++- src/PostgREST/Query/QueryBuilder.hs | 2 +- src/PostgREST/Query/SqlFragment.hs | 1 + test/spec/fixtures/data.sql | 13 ++++++++----- test/spec/fixtures/schema.sql | 1 + 6 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/PostgREST/Plan.hs b/src/PostgREST/Plan.hs index 64843d9490..923f60cb52 100644 --- a/src/PostgREST/Plan.hs +++ b/src/PostgREST/Plan.hs @@ -53,6 +53,7 @@ import PostgREST.SchemaCache.Identifiers (FieldName, QualifiedIdentifier (..), RelIdentifier (..), Schema) + --TableName) import PostgREST.SchemaCache.Relationship (Cardinality (..), Junction (..), Relationship (..), @@ -270,7 +271,7 @@ data ResolverContext = ResolverContext } resolveColumnField :: Column -> CoercibleField -resolveColumnField col = CoercibleField (colName col) mempty False (colNominalType col) Nothing (colDefault col) +resolveColumnField col = CoercibleField (colName col) mempty False (colNominalType col) Nothing (colDefault col) False resolveTableFieldName :: Table -> FieldName -> CoercibleField resolveTableFieldName table fieldName = @@ -376,15 +377,24 @@ initReadRequest ctx@ResolverContext{qi=QualifiedIdentifier{..}} = addAliases :: ReadPlanTree -> Either ApiRequestError ReadPlanTree addAliases = Right . fmap addAliasToPlan where - addAliasToPlan rp@ReadPlan{select=sel} = rp{select=map aliasSelectField sel} + addAliasToPlan rp@ReadPlan{select=sel, relIsSpread=spr} = rp{select=map (aliasSelectField spr) sel} - aliasSelectField :: CoercibleSelectField -> CoercibleSelectField - aliasSelectField field@CoercibleSelectField{csField=fieldDetails, csAggFunction=aggFun, csAlias=alias} - | isJust alias || isJust aggFun = field + aliasSelectField :: Bool -> CoercibleSelectField -> CoercibleSelectField + aliasSelectField isSpread field@CoercibleSelectField{csField=fieldDetails, csAggFunction=aggFun, csAlias=alias} + | isJust alias = field + | isJust aggFun = fieldAliasForSpreadAgg isSpread field | isJsonKeyPath fieldDetails, Just key <- lastJsonKey fieldDetails = field { csAlias = Just key } | isTransformPath fieldDetails = field { csAlias = Just (cfName fieldDetails) } | otherwise = field + fieldAliasForSpreadAgg isSpread field@CoercibleSelectField{csField=fieldDetails, csAggFunction=aggFun} = + -- A request like: `/top_table?select=...middle_table(...nested_table(count()))` will `SELECT` the full row instead of `*`, + -- because doing a `COUNT(*)` in `top_table` would not return the desired results. + -- So we use the "count" alias if none is present since the field name won't be selected. + if isSpread && cfName fieldDetails == "*" && aggFun == Just Count + then field { csAlias = Just "count" } + else field + isJsonKeyPath CoercibleField{cfJsonPath=(_: _)} = True isJsonKeyPath _ = False @@ -703,9 +713,10 @@ hoistFromSelectFields relAggAlias fields = let (modifiedField, maybeAgg) = modifyField field in (modifiedField : newFields, maybeAgg : aggList) - modifyField field@CoercibleSelectField{csAggFunction=Just aggFunc, csField, csAggCast, csAlias} = - let determineFieldName = fromMaybe (cfName csField) csAlias - updatedField = field {csAggFunction = Nothing, csAggCast = Nothing} + modifyField field@CoercibleSelectField{csAggFunction=Just aggFunc, csField=fieldDetails, csAggCast, csAlias} = + let determineFieldName = fromMaybe (cfName fieldDetails) csAlias + isFullRow = cfName fieldDetails == "*" && aggFunc == Count + updatedField = field {csField = fieldDetails{cfFullRow = isFullRow}, csAggFunction = Nothing, csAggCast = Nothing} hoistedField = Just ((relAggAlias, determineFieldName), (aggFunc, csAggCast, csAlias)) in (updatedField, hoistedField) modifyField field = (field, Nothing) @@ -858,7 +869,7 @@ addNullEmbedFilters (Node rp@ReadPlan{where_=curLogic} forest) = do newNullFilters rPlans = \case (CoercibleExpr b lOp trees) -> CoercibleExpr b lOp <$> (newNullFilters rPlans `traverse` trees) - flt@(CoercibleStmnt (CoercibleFilter (CoercibleField fld [] _ _ _ _) opExpr)) -> + flt@(CoercibleStmnt (CoercibleFilter (CoercibleField fld [] _ _ _ _ _) opExpr)) -> let foundRP = find (\ReadPlan{relName, relAlias} -> fld == fromMaybe relName relAlias) rPlans in case (foundRP, opExpr) of (Just ReadPlan{relAggAlias}, OpExpr b (Is TriNull)) -> Right $ CoercibleStmnt $ CoercibleFilterNullEmbed b relAggAlias diff --git a/src/PostgREST/Plan/Types.hs b/src/PostgREST/Plan/Types.hs index 97de469952..2d72c05419 100644 --- a/src/PostgREST/Plan/Types.hs +++ b/src/PostgREST/Plan/Types.hs @@ -39,10 +39,11 @@ data CoercibleField = CoercibleField , cfIRType :: Text -- ^ The native Postgres type of the field, the intermediate (IR) type before mapping. , cfTransform :: Maybe TransformerProc -- ^ The optional mapping from irType -> targetType. , cfDefault :: Maybe Text + , cfFullRow :: Bool -- ^ True if the field represents the whole selected row. Used in spread rels: instead of COUNT(*), it does a COUNT() in order to not mix with other spreaded resources. } deriving (Eq, Show) unknownField :: FieldName -> JsonPath -> CoercibleField -unknownField name path = CoercibleField name path False "" Nothing Nothing +unknownField name path = CoercibleField name path False "" Nothing Nothing False -- | Like an API request LogicTree, but with coercible field information. data CoercibleLogicTree diff --git a/src/PostgREST/Query/QueryBuilder.hs b/src/PostgREST/Query/QueryBuilder.hs index 96701be252..602ae27ef1 100644 --- a/src/PostgREST/Query/QueryBuilder.hs +++ b/src/PostgREST/Query/QueryBuilder.hs @@ -206,7 +206,7 @@ callPlanToQuery (FunctionCall qi params arguments returnsScalar returnsSetOfScal KeyParams [] -> "FROM " <> callIt mempty KeyParams prms -> case arguments of DirectArgs args -> "FROM " <> callIt (fmtArgs prms args) - JsonArgs json -> fromJsonBodyF json ((\p -> CoercibleField (ppName p) mempty False (ppTypeMaxLength p) Nothing Nothing) <$> prms) False True False <> ", " <> + JsonArgs json -> fromJsonBodyF json ((\p -> CoercibleField (ppName p) mempty False (ppTypeMaxLength p) Nothing Nothing False) <$> prms) False True False <> ", " <> "LATERAL " <> callIt (fmtParams prms) callIt :: SQL.Snippet -> SQL.Snippet diff --git a/src/PostgREST/Query/SqlFragment.hs b/src/PostgREST/Query/SqlFragment.hs index 39b869d5d9..b2e5884140 100644 --- a/src/PostgREST/Query/SqlFragment.hs +++ b/src/PostgREST/Query/SqlFragment.hs @@ -252,6 +252,7 @@ pgFmtCallUnary :: Text -> SQL.Snippet -> SQL.Snippet pgFmtCallUnary f x = SQL.sql (encodeUtf8 f) <> "(" <> x <> ")" pgFmtField :: QualifiedIdentifier -> CoercibleField -> SQL.Snippet +pgFmtField table CoercibleField{cfFullRow=True} = fromQi table pgFmtField table CoercibleField{cfName=fn, cfJsonPath=[]} = pgFmtColumn table fn pgFmtField table CoercibleField{cfName=fn, cfToJson=doToJson, cfJsonPath=jp} | doToJson = "to_jsonb(" <> pgFmtColumn table fn <> ")" <> pgFmtJsonPath jp | otherwise = pgFmtColumn table fn <> pgFmtJsonPath jp diff --git a/test/spec/fixtures/data.sql b/test/spec/fixtures/data.sql index 8d9a324594..ae32777064 100644 --- a/test/spec/fixtures/data.sql +++ b/test/spec/fixtures/data.sql @@ -895,11 +895,13 @@ INSERT INTO process_categories VALUES (1, 'Batch'); INSERT INTO process_categories VALUES (2, 'Mass'); TRUNCATE TABLE processes CASCADE; -INSERT INTO processes VALUES (1, 'Process A1', 1, 1); -INSERT INTO processes VALUES (2, 'Process A2', 1, 2); -INSERT INTO processes VALUES (3, 'Process B1', 2, 1); -INSERT INTO processes VALUES (4, 'Process B2', 2, 1); -INSERT INTO processes VALUES (5, 'Process C1', 3, 2); +INSERT INTO processes VALUES (1, 'Process A1', 23, 1, 1); +INSERT INTO processes VALUES (2, 'Process A2', 23, 1, 2); +INSERT INTO processes VALUES (3, 'Process B1', 23, 2, 1); +INSERT INTO processes VALUES (4, 'Process B2', 23, 2, 1); +INSERT INTO processes VALUES (5, 'Process C1', 23, 3, 2); +INSERT INTO processes VALUES (6, 'Process C2', 23, 3, 2); +INSERT INTO processes VALUES (7, 'Process XX', 23, 3, 2); TRUNCATE TABLE process_costs CASCADE; INSERT INTO process_costs VALUES (1, 150.00); @@ -922,3 +924,4 @@ INSERT INTO process_supervisor VALUES (3, 4); INSERT INTO process_supervisor VALUES (4, 1); INSERT INTO process_supervisor VALUES (4, 2); INSERT INTO process_supervisor VALUES (5, 3); +INSERT INTO process_supervisor VALUES (6, 3); diff --git a/test/spec/fixtures/schema.sql b/test/spec/fixtures/schema.sql index a3b6edda16..e15b5b6c8c 100644 --- a/test/spec/fixtures/schema.sql +++ b/test/spec/fixtures/schema.sql @@ -3763,6 +3763,7 @@ create table process_categories ( create table processes ( id int primary key, name text, + count int, factory_id int references factories(id), category_id int references process_categories(id) );