diff --git a/velox/functions/sparksql/DateTimeFunctions.cpp b/velox/functions/sparksql/DateTimeFunctions.cpp index bdbff6a88cb0..e255fefa54a9 100644 --- a/velox/functions/sparksql/DateTimeFunctions.cpp +++ b/velox/functions/sparksql/DateTimeFunctions.cpp @@ -31,14 +31,13 @@ Timestamp makeTimeStampFromDecodedArgs( DecodedVector* micros) { auto totalMicros = micros->valueAt(row); auto seconds = totalMicros / util::kMicrosPerSec; - auto nanos = totalMicros % util::kMicrosPerSec; VELOX_USER_CHECK( seconds <= 60, "Invalid value for SecondOfMinute (valid values 0 - 59): {}.", seconds); if (seconds == 60) { VELOX_USER_CHECK( - nanos == 0, + totalMicros % util::kMicrosPerSec == 0, "The fraction of sec must be zero. Valid range is [0, 60]."); } @@ -54,7 +53,7 @@ Timestamp makeTimeStampFromDecodedArgs( class MakeTimestampFunction : public exec::VectorFunction { public: - MakeTimestampFunction() = default; + MakeTimestampFunction(int64_t sessionTzID) : sessionTzID_(sessionTzID) {} void apply( const SelectivityVector& rows, @@ -62,12 +61,6 @@ class MakeTimestampFunction : public exec::VectorFunction { const TypePtr& outputType, exec::EvalCtx& context, VectorPtr& result) const override { - auto microsType = args[5]->type()->asShortDecimal(); - VELOX_USER_CHECK( - microsType.scale() == 6, - "Seconds fraction must have 6 digits for microseconds but got {}", - microsType.scale()); - context.ensureWritable(rows, TIMESTAMP(), result); auto* resultFlatVector = result->as>(); @@ -108,16 +101,10 @@ class MakeTimestampFunction : public exec::VectorFunction { } else { // Otherwise use session timezone. If session timezone is not specified, // use default value 0(UTC timezone). - int64_t sessionTzID = 0; - const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig(); - const auto sessionTzName = queryConfig.sessionTimezone(); - if (!sessionTzName.empty()) { - sessionTzID = util::getTimeZoneID(sessionTzName); - } rows.applyToSelected([&](vector_size_t row) { auto timestamp = makeTimeStampFromDecodedArgs( row, year, month, day, hour, minute, micros); - timestamp.toGMT(sessionTzID); + timestamp.toGMT(sessionTzID_); resultFlatVector->set(row, timestamp); }); } @@ -150,12 +137,35 @@ class MakeTimestampFunction : public exec::VectorFunction { .build(), }; } + + private: + int64_t sessionTzID_; }; + +std::shared_ptr createMakeTimestampFunction( + const std::string& /* name */, + const std::vector& inputArgs, + const core::QueryConfig& config) { + VELOX_USER_CHECK( + inputArgs[5].type->isShortDecimal(), + "Seconds must be short decimal type but got {}", inputArgs[5].type->toString()); + auto microsType = inputArgs[5].type->asShortDecimal(); + VELOX_USER_CHECK( + microsType.scale() == 6, + "Seconds fraction must have 6 digits for microseconds but got {}", + microsType.scale()); + + const auto sessionTzName = config.sessionTimezone(); + const auto sessionTzID = + sessionTzName.empty() ? 0 : util::getTimeZoneID(sessionTzName); + + return std::make_shared(sessionTzID); +} } // namespace -VELOX_DECLARE_VECTOR_FUNCTION( +VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_make_timestamp, MakeTimestampFunction::signatures(), - std::make_unique()); + createMakeTimestampFunction); } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp index 551e5b9b1e6c..c2e59269c645 100644 --- a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp @@ -767,7 +767,17 @@ TEST_F(DateTimeFunctionsTest, makeTimestamp) { : evaluate("make_timestamp(c0, c1, c2, c3, c4, c5)", data); facebook::velox::test::assertEqualVectors(expected, result); }; - const auto microsType = DECIMAL(18, 6); + + const auto testConstantTimezone = [&](const RowVectorPtr& data, + const std::string& timezone, + const VectorPtr& expected) { + auto result = evaluate( + fmt::format("make_timestamp(c0, c1, c2, c3, c4, c5, '{}')", timezone), + data); + facebook::velox::test::assertEqualVectors(expected, result); + }; + + const auto microsType = DECIMAL(16, 6); // Valid cases w/o timezone. { @@ -779,25 +789,28 @@ TEST_F(DateTimeFunctionsTest, makeTimestamp) { const auto micros = makeNullableFlatVector( {45678000, 1e6, 6e7, 59999999, std::nullopt}, microsType); auto data = makeRowVector({year, month, day, hour, minute, micros}); - { - auto expected = makeNullableFlatVector( - {util::fromTimestampString("2021-07-11 06:30:45.678"), - util::fromTimestampString("2021-07-11 06:30:01"), - util::fromTimestampString("2021-07-11 06:31:00"), - util::fromTimestampString("2021-07-11 06:30:59.999999"), - std::nullopt}); - testMakeTimestamp(data, expected, false); - } - { - setQueryTimeZone("Asia/Shanghai"); - auto expected = makeNullableFlatVector( - {util::fromTimestampString("2021-07-10 22:30:45.678"), - util::fromTimestampString("2021-07-10 22:30:01"), - util::fromTimestampString("2021-07-10 22:31:00"), - util::fromTimestampString("2021-07-10 22:30:59.999999"), - std::nullopt}); - testMakeTimestamp(data, expected, false); - } + + // Test w/o session timezone. + setQueryTimeZone(""); + auto expectedGMT = makeNullableFlatVector( + {util::fromTimestampString("2021-07-11 06:30:45.678"), + util::fromTimestampString("2021-07-11 06:30:01"), + util::fromTimestampString("2021-07-11 06:31:00"), + util::fromTimestampString("2021-07-11 06:30:59.999999"), + std::nullopt}); + testMakeTimestamp(data, expectedGMT, false); + testConstantTimezone(data, "GMT", expectedGMT); + + // Test w/ session timezone. + setQueryTimeZone("Asia/Shanghai"); + auto expectedSessionTimezone = makeNullableFlatVector( + {util::fromTimestampString("2021-07-10 22:30:45.678"), + util::fromTimestampString("2021-07-10 22:30:01"), + util::fromTimestampString("2021-07-10 22:31:00"), + util::fromTimestampString("2021-07-10 22:30:59.999999"), + std::nullopt}); + testMakeTimestamp(data, expectedSessionTimezone, false); + testConstantTimezone(data, "GMT", expectedGMT); } // Valid cases w/ timezone. @@ -814,6 +827,7 @@ TEST_F(DateTimeFunctionsTest, makeTimestamp) { auto data = makeRowVector({year, month, day, hour, minute, micros, timezone}); { + setQueryTimeZone(""); auto expected = makeNullableFlatVector( {util::fromTimestampString("2021-07-11 06:30:45.678"), util::fromTimestampString("2021-07-11 04:30:45.678"), @@ -864,6 +878,10 @@ TEST_F(DateTimeFunctionsTest, makeTimestamp) { 60007000, microsType, "The fraction of sec must be zero. Valid range is [0, 60]."); + testMicrosError( + 60007000, + DECIMAL(20, 8), + "Seconds must be short decimal type but got DECIMAL(20, 8)"); testMicrosError( 60007000, DECIMAL(18, 8),