diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index 430071a046..26434a2282 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -2266,6 +2266,45 @@ public static Geometry points(Geometry geometry) { return geometry.getFactory().createMultiPointFromCoords(coordinates); } + public static Geometry scale(Geometry geometry, double scaleX, double scaleY) { + return scaleGeom(geometry, Constructors.point(scaleX, scaleY)); + } + + public static Geometry scaleGeom(Geometry geometry, Geometry factor) { + return scaleGeom(geometry, factor, null); + } + + public static Geometry scaleGeom(Geometry geometry, Geometry factor, Geometry origin) { + if (geometry == null || factor == null || geometry.isEmpty() || factor.isEmpty()) { + return geometry; + } + + if (!factor.getGeometryType().equalsIgnoreCase(Geometry.TYPENAME_POINT)) { + throw new IllegalArgumentException("Scale factor geometry should be a Point type."); + } + + Geometry resultGeom = null; + AffineTransformation scaleInstance = null; + Coordinate factorCoordinate = factor.getCoordinate(); + + if (origin == null || origin.isEmpty()) { + scaleInstance = + AffineTransformation.scaleInstance(factorCoordinate.getX(), factorCoordinate.getY()); + resultGeom = scaleInstance.transform(geometry); + } else { + Coordinate falseOrigin = origin.getCoordinate(); + scaleInstance = + AffineTransformation.scaleInstance( + factorCoordinate.getX(), + factorCoordinate.getY(), + falseOrigin.getX(), + falseOrigin.getY()); + resultGeom = scaleInstance.transform(geometry); + } + + return resultGeom; + } + public static Geometry rotateX(Geometry geometry, double angle) { if (GeomUtils.isAnyGeomEmpty(geometry)) { return geometry; diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index cd60d16406..b554339c93 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -3994,6 +3994,44 @@ public void points() throws ParseException { assertEquals("MULTIPOINT Z((0 0 1), (1 1 2), (2 2 3), (0 0 1))", result1); } + @Test + public void scale() throws ParseException { + Geometry geom = Constructors.geomFromWKT("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))", 0); + Geometry actual = Functions.scale(geom, 3, 2); + String expected = "POLYGON ((0 0, 0 2, 3 2, 3 0, 0 0))"; + assertEquals(expected, actual.toString()); + + geom = Constructors.geomFromWKT("LINESTRING(0 1, 1 0)", 0); + actual = Functions.scale(geom, 10, 5); + expected = "LINESTRING (0 5, 10 0)"; + assertEquals(expected, actual.toString()); + + geom = Constructors.geomFromWKT("POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0))", 1111); + actual = Functions.scaleGeom(geom, Constructors.point(1.8, 2.1)); + expected = "POLYGON ((0 0, 0 3.1500000000000004, 2.7 3.1500000000000004, 2.7 0, 0 0))"; + assertEquals(expected, actual.toString()); + assertEquals(1111, actual.getSRID()); + + actual = + Functions.scaleGeom(geom, Constructors.point(3, 2), Constructors.point(0.32959, 0.796483)); + expected = + "POLYGON ((-0.6591799999999999 -0.796483, -0.6591799999999999 2.2035169999999997, 3.84082 2.2035169999999997, 3.84082 -0.796483, -0.6591799999999999 -0.796483))"; + assertEquals(expected, actual.toString()); + + // test to check Z and M ordinate preservation + geom = Constructors.geomFromWKT("POLYGON ((0 0 1, 0 1.5 2, 1.5 1.5 2, 1.5 0 3, 0 0 1))", 0); + String actualWKT = Functions.asWKT(Functions.scale(geom, 3, 2)); + expected = "POLYGON Z((0 0 1, 0 3 2, 4.5 3 2, 4.5 0 3, 0 0 1))"; + assertEquals(expected, actualWKT); + + geom = + Constructors.geomFromWKT( + "POLYGON ZM((0 0 1 2, 0 1.5 2 2, 1.5 1.5 2 2, 1.5 0 3 2, 0 0 1 2))", 0); + actualWKT = Functions.asWKT(Functions.scale(geom, 3, 2)); + expected = "POLYGON ZM((0 0 1 2, 0 3 2 2, 4.5 3 2 2, 4.5 0 3 2, 0 0 1 2))"; + assertEquals(expected, actualWKT); + } + @Test public void rotateX() throws ParseException { Geometry lineString = Constructors.geomFromEWKT("LINESTRING (50 160, 50 50, 100 50)"); diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index bef6465214..21728c5727 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -100,6 +100,8 @@ public static UserDefinedFunction[] getFuncs() { new Functions.ST_FlipCoordinates(), new Functions.ST_GeoHash(), new Functions.ST_PointOnSurface(), + new Functions.ST_Scale(), + new Functions.ST_ScaleGeom(), new Functions.ST_ReducePrecision(), new Functions.ST_Reverse(), new Functions.ST_Rotate(), diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java index ccaefc2183..bccdbe2364 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java @@ -1977,6 +1977,44 @@ public String eval( } } + public static class ST_Scale extends ScalarFunction { + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval( + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o, + @DataTypeHint(value = "Double") Double scaleX, + @DataTypeHint(value = "Double") Double scaleY) { + Geometry geometry = (Geometry) o; + return org.apache.sedona.common.Functions.scale(geometry, scaleX, scaleY); + } + } + + public static class ST_ScaleGeom extends ScalarFunction { + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval( + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + Object o1, + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + Object o2, + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + Object o3) { + Geometry geometry = (Geometry) o1; + Geometry factor = (Geometry) o2; + Geometry origin = (Geometry) o3; + return org.apache.sedona.common.Functions.scaleGeom(geometry, factor, origin); + } + + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval( + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + Object o1, + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + Object o2) { + Geometry geometry = (Geometry) o1; + Geometry factor = (Geometry) o2; + return org.apache.sedona.common.Functions.scaleGeom(geometry, factor); + } + } + public static class ST_RotateX extends ScalarFunction { @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) public Geometry eval( diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 4de98c0fd3..cc1a137a30 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -2654,6 +2654,52 @@ public void testIsValidReason() { // standards } + @Test + public void testScale() { + Table tbl = + tableEnv.sqlQuery( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 2 0, 1 1, 2 2, 0 2, 1 1, 0 0))', 1010) AS geom"); + Geometry actual = + (Geometry) + first(tbl.select(call(Functions.ST_Scale.class.getSimpleName(), $("geom"), 2, 3))) + .getField(0); + String expected = "POLYGON ((0 0, 4 0, 2 3, 4 6, 0 6, 2 3, 0 0))"; + assertEquals(expected, actual.toString()); + assertEquals(1010, actual.getSRID()); + } + + @Test + public void testScaleGeom() { + Table tbl = + tableEnv.sqlQuery( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 2 0, 1 1, 2 2, 0 2, 1 1, 0 0))', 1010) AS geom, ST_GeomFromWKT('POINT (2 3)') AS factor"); + Geometry actual = + (Geometry) + first( + tbl.select( + call(Functions.ST_ScaleGeom.class.getSimpleName(), $("geom"), $("factor")))) + .getField(0); + String expected = "POLYGON ((0 0, 4 0, 2 3, 4 6, 0 6, 2 3, 0 0))"; + assertEquals(expected, actual.toString()); + assertEquals(1010, actual.getSRID()); + + tbl = + tableEnv.sqlQuery( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 2 0, 1 1, 2 2, 0 2, 1 1, 0 0))', 1010) AS geom, ST_GeomFromWKT('POINT (2 3)') AS factor, ST_GeomFromWKT('POINT (-1 0)') AS origin"); + actual = + (Geometry) + first( + tbl.select( + call( + Functions.ST_ScaleGeom.class.getSimpleName(), + $("geom"), + $("factor"), + $("origin")))) + .getField(0); + expected = "POLYGON ((1 0, 5 0, 3 3, 5 6, 1 6, 3 3, 1 0))"; + assertEquals(expected, actual.toString()); + } + @Test public void testRotateX() { Table tbl = diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index 0d66725a5e..e766c33fdd 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -2298,6 +2298,40 @@ def ST_IsCollection(geometry: ColumnOrName) -> Column: return _call_st_function("ST_IsCollection", geometry) +@validate_argument_types +def ST_Scale( + geometry: ColumnOrName, + scaleX: Union[ColumnOrNameOrNumber, float], + scaleY: Union[ColumnOrNameOrNumber, float], +) -> Column: + """Scale geometry with X and Y axis. + + @param geometry: + @param scaleX: + @param scaleY: + @return: + """ + return _call_st_function("ST_Scale", (geometry, scaleX, scaleY)) + + +@validate_argument_types +def ST_ScaleGeom( + geometry: ColumnOrName, factor: ColumnOrName, origin: Optional[ColumnOrName] = None +) -> Column: + """Scale geometry with the corodinates of factor geometry + + @param geometry: + @param factor: + @param origin: + @return: + """ + if origin is not None: + args = (geometry, factor, origin) + else: + args = (geometry, factor) + return _call_st_function("ST_ScaleGeom", args) + + @validate_argument_types def ST_RotateX(geometry: ColumnOrName, angle: Union[ColumnOrName, float]) -> Column: """Returns geometry rotated by the given angle in X axis diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 2af7e33e91..3d40267a84 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -845,6 +845,20 @@ "", "LINESTRING (5 0, 4 0, 3 0, 2 0, 1 0, 0 0)", ), + ( + stf.ST_Scale, + ("poly", 3, 2), + "poly_and_point", + "", + "POLYGON ((0 0, 0 2, 3 2, 3 0, 0 0))", + ), + ( + stf.ST_ScaleGeom, + ("poly", "point"), + "poly_and_point", + "", + "POLYGON ((0 0, 0 2, 3 2, 3 0, 0 0))", + ), ( stf.ST_RotateX, ("line", 10.0), @@ -1316,6 +1330,8 @@ (stf.ST_RemovePoint, ("", 1.0)), (stf.ST_RemoveRepeatedPoints, (None, None)), (stf.ST_Reverse, (None,)), + (stf.ST_Scale, (None, None, None)), + (stf.ST_ScaleGeom, (None, None, None)), ( stf.ST_Rotate, ( @@ -1592,6 +1608,10 @@ def base_df(self, request): return TestDataFrameAPI.spark.sql( "SELECT array(ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))'), ST_GeomFromWKT('POLYGON ((-2 1, 2 1, 2 4, -2 4, -2 1))')) as polys" ) + elif request.param == "poly_and_point": + return TestDataFrameAPI.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))') AS poly, ST_GeomFromWKT('POINT (3 2)') AS point" + ) elif request.param == "poly_and_line": return TestDataFrameAPI.spark.sql( "SELECT ST_GeomFromWKT('POLYGON((2.6 12.5, 2.6 20.0, 12.6 20.0, 12.6 12.5, 2.6 12.5 ))') as poly, ST_GeomFromWKT('LINESTRING (0.5 10.7, 5.4 8.4, 10.1 10.0)') as line" diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index 367e0e0554..b83343a0ae 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1345,6 +1345,32 @@ def test_st_add_point(self): ] assert collected_geometries[0] == "LINESTRING (0 0, 1 1, 1 0, 21 52)" + def test_st_scale(self): + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') AS geom" + ) + actual = baseDf.selectExpr("ST_AsText(ST_Scale(geom, -10, 5))").first()[0] + expected = "LINESTRING (-500 800, -500 250, -1000 250)" + assert expected == actual + + def test_st_scalegeom(self): + baseDf = self.spark.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0))') AS geometry, ST_GeomFromWKT('POINT (1.8 2.1)') AS factor, ST_GeomFromWKT('POINT (0.32959 0.796483)') AS origin" + ) + actual = baseDf.selectExpr("ST_AsText(ST_ScaleGeom(geometry, factor))").first()[ + 0 + ] + expected = ( + "POLYGON ((0 0, 0 3.1500000000000004, 2.7 3.1500000000000004, 2.7 0, 0 0))" + ) + assert expected == actual + + actual = baseDf.selectExpr( + "ST_AsText(ST_ScaleGeom(geometry, factor, origin))" + ).first()[0] + expected = "POLYGON ((-0.263672 -0.8761313000000002, -0.263672 2.2738687000000004, 2.436328 2.2738687000000004, 2.436328 -0.8761313000000002, -0.263672 -0.8761313000000002))" + assert expected == actual + def test_st_rotate_x(self): baseDf = self.spark.sql( "SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') as geom1, ST_GeomFromWKT('LINESTRING(1 2 3, 1 1 1)') AS geom2" diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java index ae970d591c..f4271b8fa7 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctions.java @@ -1237,6 +1237,27 @@ public void test_ST_Translate() { "GEOMETRYCOLLECTION Z(MULTIPOLYGON Z(((3 2 3, 3 3 3, 4 3 3, 4 2 3, 3 2 3)), ((3 4 3, 5 6 3, 5 7 3, 3 4 3))), POINT Z(3 3 4), LINESTRING ZEMPTY)"); } + @Test + public void test_ST_Scale() { + registerUDF("ST_Scale", byte[].class, double.class, double.class); + verifySqlSingleRes( + "SELECT sedona.ST_AsText(sedona.ST_Scale(sedona.ST_GeomFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), 3, 2))", + "POLYGON ((0 0, 0 2, 3 2, 3 0, 0 0))"); + } + + @Test + public void test_ST_ScaleGeom() { + registerUDF("ST_ScaleGeom", byte[].class, byte[].class, byte[].class); + verifySqlSingleRes( + "SELECT sedona.ST_AsText(sedona.ST_ScaleGeom(sedona.ST_GeomFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), sedona.ST_Point(3, 2), sedona.ST_Point(1, 2)))", + "POLYGON ((-2 -2, -2 0, 1 0, 1 -2, -2 -2))"); + + registerUDF("ST_ScaleGeom", byte[].class, byte[].class); + verifySqlSingleRes( + "SELECT sedona.ST_AsText(sedona.ST_ScaleGeom(sedona.ST_GeomFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), sedona.ST_Point(3, 2)))", + "POLYGON ((0 0, 0 2, 3 2, 3 0, 0 0))"); + } + @Test public void test_ST_RotateX() { registerUDF("ST_RotateX", byte[].class, double.class); diff --git a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java index f64d445958..33f5ba01d8 100644 --- a/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java +++ b/snowflake-tester/src/test/java/org/apache/sedona/snowflake/snowsql/TestFunctionsV2.java @@ -1192,6 +1192,27 @@ public void test_ST_Translate() { "POINT(2 5)"); } + @Test + public void test_ST_Scale() { + registerUDFV2("ST_Scale", String.class, double.class, double.class); + verifySqlSingleRes( + "SELECT ST_AsText(sedona.ST_Scale(ST_GeometryFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), 3, 2))", + "POLYGON((0 0,0 2,3 2,3 0,0 0))"); + } + + @Test + public void test_ST_ScaleGeom() { + registerUDFV2("ST_ScaleGeom", String.class, String.class, String.class); + verifySqlSingleRes( + "SELECT ST_AsText(sedona.ST_ScaleGeom(ST_GeometryFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), ST_Point(3, 2), ST_Point(1, 2)))", + "POLYGON((-2 -2,-2 0,1 0,1 -2,-2 -2))"); + + registerUDFV2("ST_ScaleGeom", String.class, String.class); + verifySqlSingleRes( + "SELECT ST_AsText(sedona.ST_ScaleGeom(ST_GeometryFromWKT('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), ST_Point(3, 2)))", + "POLYGON((0 0,0 2,3 2,3 0,0 0))"); + } + @Test public void test_ST_RotateX() { registerUDFV2("ST_RotateX", String.class, double.class); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java index 761204ab66..83c075cd6b 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFs.java @@ -1290,6 +1290,28 @@ public static byte[] ST_Translate(byte[] geom, double deltaX, double deltaY, dou Functions.translate(GeometrySerde.deserialize(geom), deltaX, deltaY, deltaZ)); } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "scaleX", "scaleY"}) + public static byte[] ST_Scale(byte[] geometry, double scaleX, double scaleY) { + return GeometrySerde.serialize( + Functions.scale(GeometrySerde.deserialize(geometry), scaleX, scaleY)); + } + + @UDFAnnotations.ParamMeta(argNames = {"geometry", "factor", "origin"}) + public static byte[] ST_ScaleGeom(byte[] geometry, byte[] factor, byte[] origin) { + return GeometrySerde.serialize( + Functions.scaleGeom( + GeometrySerde.deserialize(geometry), + GeometrySerde.deserialize(factor), + GeometrySerde.deserialize(origin))); + } + + @UDFAnnotations.ParamMeta(argNames = {"geometry", "factor"}) + public static byte[] ST_ScaleGeom(byte[] geometry, byte[] factor) { + return GeometrySerde.serialize( + Functions.scaleGeom( + GeometrySerde.deserialize(geometry), GeometrySerde.deserialize(factor))); + } + @UDFAnnotations.ParamMeta(argNames = {"geometry", "angle"}) public static byte[] ST_RotateX(byte[] geometry, double angle) { return GeometrySerde.serialize(Functions.rotateX(GeometrySerde.deserialize(geometry), angle)); diff --git a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java index 17c099eab6..a645f87836 100644 --- a/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java +++ b/snowflake/src/main/java/org/apache/sedona/snowflake/snowsql/UDFsV2.java @@ -1535,6 +1535,37 @@ public static String ST_Translate(String geom, double deltaX, double deltaY, dou Functions.translate(GeometrySerde.deserGeoJson(geom), deltaX, deltaY, deltaZ)); } + @UDFAnnotations.ParamMeta( + argNames = {"geometry", "scaleX", "scaleY"}, + argTypes = {"Geometry", "double", "double"}, + returnTypes = "Geometry") + public static String ST_Scale(String geometry, double scaleX, double scaleY) { + return GeometrySerde.serGeoJson( + Functions.scale(GeometrySerde.deserGeoJson(geometry), scaleX, scaleY)); + } + + @UDFAnnotations.ParamMeta( + argNames = {"geometry", "factor", "origin"}, + argTypes = {"Geometry", "Geometry", "Geometry"}, + returnTypes = "Geometry") + public static String ST_ScaleGeom(String geometry, String factor, String origin) { + return GeometrySerde.serGeoJson( + Functions.scaleGeom( + GeometrySerde.deserGeoJson(geometry), + GeometrySerde.deserGeoJson(factor), + GeometrySerde.deserGeoJson(origin))); + } + + @UDFAnnotations.ParamMeta( + argNames = {"geometry", "factor"}, + argTypes = {"Geometry", "Geometry"}, + returnTypes = "Geometry") + public static String ST_ScaleGeom(String geometry, String factor) { + return GeometrySerde.serGeoJson( + Functions.scaleGeom( + GeometrySerde.deserGeoJson(geometry), GeometrySerde.deserGeoJson(factor))); + } + @UDFAnnotations.ParamMeta( argNames = {"geometry", "angle"}, argTypes = {"Geometry", "double"}, diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 6978b5162d..f2ff868b45 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -234,6 +234,8 @@ object Catalog { function[ST_HausdorffDistance](-1), function[ST_DWithin](), function[ST_IsValidReason](), + function[ST_Scale](), + function[ST_ScaleGeom](), function[ST_Rotate](), function[ST_RotateX](), function[ST_RotateY](), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 725796ee60..b3bf973b9e 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -1726,6 +1726,22 @@ case class ST_IsValidReason(inputExpressions: Seq[Expression]) copy(inputExpressions = newChildren) } +case class ST_Scale(inputExpressions: Seq[Expression]) + extends InferredExpression(inferrableFunction3(Functions.scale)) { + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = + copy(inputExpressions = newChildren) +} + +case class ST_ScaleGeom(inputExpressions: Seq[Expression]) + extends InferredExpression( + inferrableFunction3(Functions.scaleGeom), + inferrableFunction2(Functions.scaleGeom)) { + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = + copy(inputExpressions = newChildren) +} + case class ST_RotateX(inputExpressions: Seq[Expression]) extends InferredExpression(inferrableFunction2(Functions.rotateX)) { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index 33e6760e24..56a2227174 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -496,6 +496,23 @@ object st_functions extends DataFrameAPI { def ST_Reverse(geometry: Column): Column = wrapExpression[ST_Reverse](geometry) def ST_Reverse(geometry: String): Column = wrapExpression[ST_Reverse](geometry) + def ST_Scale(geometry: Column, scaleX: Column, scaleY: Column): Column = + wrapExpression[ST_Scale](geometry, scaleX, scaleY) + def ST_Scale(geometry: String, scaleX: Double, scaleY: Double): Column = + wrapExpression[ST_Scale](geometry, scaleX, scaleY) + def ST_Scale(geometry: String, scaleX: String, scaleY: String): Column = + wrapExpression[ST_Scale](geometry, scaleX, scaleY) + + def ST_ScaleGeom(geometry: Column, factor: Column): Column = + wrapExpression[ST_ScaleGeom](geometry, factor) + def ST_ScaleGeom(geometry: String, factor: String): Column = + wrapExpression[ST_ScaleGeom](geometry, factor) + + def ST_ScaleGeom(geometry: Column, factor: Column, origin: Column): Column = + wrapExpression[ST_ScaleGeom](geometry, factor, origin) + def ST_ScaleGeom(geometry: String, factor: String, origin: String): Column = + wrapExpression[ST_ScaleGeom](geometry, factor, origin) + def ST_RotateX(geometry: Column, angle: Column): Column = wrapExpression[ST_RotateX](geometry, angle) def ST_RotateX(geometry: String, angle: Double): Column = diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala index c67525427f..ea6092629f 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/PreserveSRIDSuite.scala @@ -112,6 +112,8 @@ class PreserveSRIDSuite extends TestBaseScala with TableDrivenPropertyChecks { ("ST_Affine(geom1, 1, 2, 1, 2, 1, 2)", 1000), ("ST_BoundingDiagonal(geom1)", 1000), ("ST_DelaunayTriangles(geom4)", 1000), + ("ST_Scale(geom1, 1, 2)", 1000), + ("ST_ScaleGeom(geom1, geom6)", 1000), ("ST_Rotate(geom1, 10)", 1000), ("ST_RotateX(geom1, 10)", 1000), ("ST_Collect(geom1, geom2, geom3)", 1000), @@ -141,7 +143,8 @@ class PreserveSRIDSuite extends TestBaseScala with TableDrivenPropertyChecks { StructField("geom2", GeometryUDT), StructField("geom3", GeometryUDT), StructField("geom4", GeometryUDT), - StructField("geom5", GeometryUDT))) + StructField("geom5", GeometryUDT), + StructField("geom6", GeometryUDT))) val geom1 = Constructors.geomFromWKT("POLYGON ((0 0, 1 0, 0.5 0.5, 1 1, 0 1, 0 0))", 1000) val geom2 = Constructors.geomFromWKT("MULTILINESTRING ((0 0, 0 1), (0 1, 1 1), (1 1, 1 0))", 1000) @@ -155,7 +158,8 @@ class PreserveSRIDSuite extends TestBaseScala with TableDrivenPropertyChecks { |LINESTRING (2 2, 3 2, 4 2), LINESTRING (0 2, 1 3, 2 4), |LINESTRING (2 4, 3 3, 4 2))""".stripMargin, 1000) - val rows = Seq(Row(geom1, geom2, geom3, geom4, geom5)) + val geom6 = Constructors.geomFromWKT("POINT (1 2)", 1000) + val rows = Seq(Row(geom1, geom2, geom3, geom4, geom5, geom6)) sparkSession.createDataFrame(rows.asJava, schema) } } diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 7daee87f05..0f60682217 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -2286,6 +2286,28 @@ class dataFrameAPITestScala extends TestBaseScala { assertEquals("SRID=4326;POINT ZM(1 2 3 100)", point2) } + it("Should pass ST_Scale") { + val baseDf = + sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') AS geom") + val actual = baseDf.select(ST_AsText(ST_Scale("geom", -10, 5))).first().get(0) + val expected = "LINESTRING (-500 800, -500 250, -1000 250)" + assertEquals(expected, actual) + } + + it("Should pass ST_ScaleGeom") { + val baseDf = sparkSession.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0))') AS geometry, ST_GeomFromWKT('POINT (1.8 2.1)') AS factor, ST_GeomFromWKT('POINT (0.32959 0.796483)') AS origin") + var actual = baseDf.select(ST_AsText(ST_ScaleGeom("geometry", "factor"))).first().get(0) + var expected = "POLYGON ((0 0, 0 3.1500000000000004, 2.7 3.1500000000000004, 2.7 0, 0 0))" + assertEquals(expected, actual) + + actual = + baseDf.select(ST_AsText(ST_ScaleGeom("geometry", "factor", "origin"))).first().get(0) + expected = + "POLYGON ((-0.263672 -0.8761313000000002, -0.263672 2.2738687000000004, 2.436328 2.2738687000000004, 2.436328 -0.8761313000000002, -0.263672 -0.8761313000000002))" + assertEquals(expected, actual) + } + it("Should pass ST_RotateX") { val geomTestCases = Map( ( diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index f18c4581fb..e40fcd5b59 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -3454,6 +3454,27 @@ class functionTestScala } } + it("Should pass ST_Scale") { + val baseDf = + sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (50 160, 50 50, 100 50)') AS geom") + val actual = baseDf.selectExpr("ST_AsText(ST_Scale(geom, -10, 5))").first().get(0) + val expected = "LINESTRING (-500 800, -500 250, -1000 250)" + assertEquals(expected, actual) + } + + it("Should pass ST_ScaleGeom") { + val baseDf = sparkSession.sql( + "SELECT ST_GeomFromWKT('POLYGON ((0 0, 0 1.5, 1.5 1.5, 1.5 0, 0 0))') AS geometry, ST_GeomFromWKT('POINT (1.8 2.1)') AS factor, ST_GeomFromWKT('POINT (0.32959 0.796483)') AS origin") + var actual = baseDf.selectExpr("ST_AsText(ST_ScaleGeom(geometry, factor))").first().get(0) + var expected = "POLYGON ((0 0, 0 3.1500000000000004, 2.7 3.1500000000000004, 2.7 0, 0 0))" + assertEquals(expected, actual) + + actual = baseDf.selectExpr("ST_AsText(ST_ScaleGeom(geometry, factor, origin))").first().get(0) + expected = + "POLYGON ((-0.263672 -0.8761313000000002, -0.263672 2.2738687000000004, 2.436328 2.2738687000000004, 2.436328 -0.8761313000000002, -0.263672 -0.8761313000000002))" + assertEquals(expected, actual) + } + it("Should pass ST_RotateX") { val geomTestCases = Map( (