Snowpark XGBRegressor Ignores Sample Weights, Producing Identical Predictions for Different Models #111

Open
robertlessmore opened this issue Jul 29, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@robertlessmore

  1. What version of Python are you using?

Python 3.11.8 | packaged by Anaconda, Inc. | (main, Feb 26 2024, 21:34:05) [MSC v.1916 64 bit (AMD64)]

What operating system and processor architecture are you using?
Windows-10-10.0.22631-SP0

  1. What are the component versions in the environment?

_py-xgboost-mutex 2.0 cpu_0
abseil-cpp 20220623.0 h0e60522_0 conda-forge
absl-py 1.4.0 py311haa95532_0
aiobotocore 2.7.0 py311haa95532_0
aiohttp 3.9.3 py311h2bbff1b_0
aioitertools 0.7.1 pyhd3eb1b0_0
aiosignal 1.2.0 pyhd3eb1b0_0
alembic 1.8.1 py311haa95532_0
anyio 3.5.0 py311haa95532_0
appdirs 1.4.4 pyhd3eb1b0_0
argon2-cffi 21.3.0 pyhd3eb1b0_0
argon2-cffi-bindings 21.2.0 py311h2bbff1b_0
arrow-cpp 10.0.1 h9c18f36_4_cpu conda-forge
asn1crypto 1.5.1 py311haa95532_0
asttokens 2.0.5 pyhd3eb1b0_0
async-lru 2.0.4 py311haa95532_0
attrs 23.1.0 py311haa95532_0
aws-c-auth 0.6.19 h2bbff1b_0
aws-c-cal 0.5.20 h2bbff1b_0
aws-c-common 0.8.5 h2bbff1b_0
aws-c-compression 0.2.16 h2bbff1b_0
aws-c-event-stream 0.2.15 hd77b12b_0
aws-c-http 0.6.25 h2bbff1b_0
aws-c-io 0.13.10 h2bbff1b_0
aws-c-mqtt 0.7.13 h2bbff1b_0
aws-c-s3 0.1.51 h2bbff1b_0
aws-c-sdkutils 0.1.6 h2bbff1b_0
aws-checksums 0.1.13 h2bbff1b_0
aws-crt-cpp 0.18.16 hd77b12b_0
aws-sdk-cpp 1.9.379 h2768dcf_5 conda-forge
babel 2.11.0 py311haa95532_0
beautifulsoup4 4.12.2 py311haa95532_0
blas 1.0 mkl
bleach 4.1.0 pyhd3eb1b0_0
blinker 1.6.2 py311haa95532_0
boost-cpp 1.82.0 h59b6b97_2
botocore 1.31.64 py311haa95532_0
bottleneck 1.3.7 py311hd7041d2_0
brotli 1.0.9 h2bbff1b_7
brotli-bin 1.0.9 h2bbff1b_7
brotli-python 1.0.9 py311hd77b12b_7
bzip2 1.0.8 h2bbff1b_5
c-ares 1.19.1 h2bbff1b_0
ca-certificates 2024.7.2 haa95532_0 https://repo.anaconda.com/pkgs/snowflake
cachetools 4.2.2 pyhd3eb1b0_0
certifi 2024.7.4 py311haa95532_0 https://repo.anaconda.com/pkgs/snowflake
cffi 1.16.0 py311h2bbff1b_0
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.7 py311haa95532_0
cloudpickle 2.2.1 py311haa95532_0
colorama 0.4.6 py311haa95532_0
comm 0.2.1 py311haa95532_0
contourpy 1.2.0 py311h59b6b97_0
cryptography 41.0.7 py311h89fc84f_0
cycler 0.11.0 pyhd3eb1b0_0
databricks-cli 0.17.6 py311haa95532_1
debugpy 1.6.7 py311hd77b12b_0
decorator 5.1.1 pyhd3eb1b0_0
defusedxml 0.7.1 pyhd3eb1b0_0
docker-py 4.4.1 py311haa95532_5
docker-pycreds 0.4.0 pyhd3eb1b0_0
entrypoints 0.4 py311haa95532_0
executing 0.8.3 pyhd3eb1b0_0
filelock 3.13.1 py311haa95532_0
flask 2.2.5 py311haa95532_0
fonttools 4.25.0 pyhd3eb1b0_0
freetype 2.12.1 ha860e81_0
frozenlist 1.4.0 py311h2bbff1b_0
fsspec 2023.10.0 py311haa95532_0
gflags 2.2.2 hd77b12b_1
gitdb 4.0.7 pyhd3eb1b0_0
gitpython 3.1.37 py311haa95532_0
glog 0.6.0 h4797de2_0 conda-forge
greenlet 3.0.1 py311hd77b12b_0
grpc-cpp 1.51.1 h9c18f36_0 conda-forge
icc_rt 2022.1.0 h6049295_2
icu 73.1 h6c2663c_0
idna 3.4 py311haa95532_0
importlib-metadata 6.0.0 py311haa95532_0
importlib_resources 6.1.1 py311haa95532_1
intel-openmp 2023.1.0 h59b6b97_46320
ipykernel 6.28.0 py311haa95532_0
ipython 8.20.0 py311haa95532_0
ipywidgets 8.1.2 py311haa95532_0
itsdangerous 2.0.1 pyhd3eb1b0_0
jedi 0.18.1 py311haa95532_1
jinja2 3.1.3 py311haa95532_0
jmespath 1.0.1 py311haa95532_0
joblib 1.2.0 py311haa95532_0
jpeg 9e h2bbff1b_1
json5 0.9.6 pyhd3eb1b0_0
jsonschema 4.19.2 py311haa95532_0
jsonschema-specifications 2023.7.1 py311haa95532_0
jupyter 1.0.0 py311haa95532_9
jupyter-lsp 2.2.0 py311haa95532_0
jupyter_client 8.6.0 py311haa95532_0
jupyter_console 6.6.3 py311haa95532_0
jupyter_core 5.5.0 py311haa95532_0
jupyter_events 0.8.0 py311haa95532_0
jupyter_server 2.10.0 py311haa95532_0
jupyter_server_terminals 0.4.4 py311haa95532_1
jupyterlab 4.0.11 py311haa95532_0
jupyterlab_pygments 0.1.2 py_0
jupyterlab_server 2.25.1 py311haa95532_0
jupyterlab_widgets 3.0.10 py311haa95532_0
kiwisolver 1.4.4 py311hd77b12b_0
krb5 1.20.1 h5b6d351_0
lerc 3.0 hd77b12b_0
libabseil 20220623.0 cxx17_h1a56200_6 conda-forge
libarrow 10.0.1 h226723c_4_cpu conda-forge
libboost 1.82.0 h3399ecb_2
libbrotlicommon 1.0.9 h2bbff1b_7
libbrotlidec 1.0.9 h2bbff1b_7
libbrotlienc 1.0.9 h2bbff1b_7
libclang 14.0.6 default_hb5a9fac_1
libclang13 14.0.6 default_h8e68704_1
libcrc32c 1.1.2 hd77b12b_0
libcurl 8.5.0 h86230a5_0
libdeflate 1.17 h2bbff1b_1
libevent 2.1.12 h56d1f94_1
libffi 3.4.4 hd77b12b_0
libgoogle-cloud 2.5.0 h5fc25aa_1 conda-forge
libgrpc 1.51.1 h6a6baca_0 conda-forge
libpng 1.6.39 h8cc25b3_0
libpq 12.17 h906ac69_0
libprotobuf 3.21.12 h12be248_2 conda-forge
libsodium 1.0.18 h62dcd97_0
libssh2 1.10.0 he2ea4bf_2
libthrift 0.16.0 h9ce19ad_2 conda-forge
libtiff 4.5.1 hd77b12b_0
libutf8proc 2.8.0 h82a8f57_0 conda-forge
libwebp-base 1.3.2 h2bbff1b_0
libxgboost 1.7.3 hd77b12b_0
libzlib 1.2.13 hcfcfb64_5 conda-forge
lz4-c 1.9.4 h2bbff1b_0
mako 1.2.3 py311haa95532_0
markdown 3.4.1 py311haa95532_0
markupsafe 2.1.3 py311h2bbff1b_0
matplotlib-base 3.8.0 py311hf62ec03_0
matplotlib-inline 0.1.6 py311haa95532_0
mistune 2.0.4 py311haa95532_0
mkl 2023.1.0 h6b88ed4_46358
mkl-service 2.4.0 py311h2bbff1b_1
mkl_fft 1.3.8 py311h2bbff1b_0
mkl_random 1.2.4 py311h59b6b97_0
mlflow 2.3.1 py311hd1fac3c_0
multidict 6.0.4 py311h2bbff1b_0
munkres 1.1.4 py_0
nbclient 0.8.0 py311haa95532_0
nbconvert 7.10.0 py311haa95532_0
nbformat 5.9.2 py311haa95532_0
nest-asyncio 1.6.0 py311haa95532_0
notebook 7.0.8 py311haa95532_0
notebook-shim 0.2.3 py311haa95532_0
numexpr 2.8.7 py311h1fcbade_0
numpy 1.26.4 py311hdab7c0b_0
numpy-base 1.26.4 py311hd01c5d8_0
oauthlib 3.2.2 py311haa95532_0
openjpeg 2.4.0 h4fc8c34_0
openssl 3.3.0 hcfcfb64_0 conda-forge
orc 1.9.0 hada7b9e_1 conda-forge
overrides 7.4.0 py311haa95532_0
packaging 23.1 py311haa95532_0
pandas 1.5.3 py311heda8569_0
pandocfilters 1.5.0 pyhd3eb1b0_0
parso 0.8.3 pyhd3eb1b0_0
patsy 0.5.6 pyhd8ed1ab_0 conda-forge
pillow 10.2.0 py311h2bbff1b_0
pip 23.3.1 py311haa95532_0
platformdirs 3.10.0 py311haa95532_0
ply 3.11 py311haa95532_0
prometheus_client 0.14.1 py311haa95532_0
prompt-toolkit 3.0.43 py311haa95532_0
prompt_toolkit 3.0.43 hd3eb1b0_0
protobuf 4.21.12 py311h12c1d0e_0 conda-forge
psutil 5.9.0 py311h2bbff1b_0
pure_eval 0.2.2 pyhd3eb1b0_0
py-xgboost 1.7.3 py311haa95532_0
pyarrow 10.0.1 py311h8a3a540_0
pycparser 2.21 pyhd3eb1b0_0
pygments 2.15.1 py311haa95532_1
pyjwt 2.4.0 py311haa95532_0
pyopenssl 23.2.0 py311haa95532_0
pyparsing 3.0.9 py311haa95532_0
pyqt 5.15.10 py311hd77b12b_0
pyqt5-sip 12.13.0 py311h2bbff1b_0
pysocks 1.7.1 py311haa95532_0
python 3.11.8 he1021f5_0
python-dateutil 2.8.3+snowflake1 py311haa95532_1 https://repo.anaconda.com/pkgs/snowflake
python-fastjsonschema 2.16.2 py311haa95532_0
python-json-logger 2.0.7 py311haa95532_0
python_abi 3.11 2_cp311 conda-forge
pytimeparse 1.1.8 py311haa95532_0
pytz 2023.3.post1 py311haa95532_0
pywin32 305 py311h2bbff1b_0
pywinpty 2.0.10 py311h5da7b33_0
pyyaml 6.0.1 py311h2bbff1b_0
pyzmq 25.1.2 py311hd77b12b_0
qt-main 5.15.2 h19c9488_10
qtconsole 5.5.1 py311haa95532_0
qtpy 2.4.1 py311haa95532_0
querystring_parser 1.2.4 py311haa95532_0
re2 2022.06.01 h0e60522_1 conda-forge
referencing 0.30.2 py311haa95532_0
requests 2.31.0 py311haa95532_1
retrying 1.3.3 pyhd3eb1b0_2
rfc3339-validator 0.1.4 py311haa95532_0
rfc3986-validator 0.1.1 py311haa95532_0
rpds-py 0.10.6 py311h062c2fa_0
s3fs 2023.10.0 py311haa95532_0
scikit-learn 1.2.2 py311hd77b12b_1
scipy 1.11.4 py311hc1ccb85_0
seaborn 0.13.2 hd8ed1ab_2 conda-forge
seaborn-base 0.13.2 pyhd8ed1ab_2 conda-forge
send2trash 1.8.2 py311haa95532_0
setuptools 68.2.2 py311haa95532_0
sip 6.7.12 py311hd77b12b_0
six 1.16.0 pyhd3eb1b0_1
smmap 4.0.0 pyhd3eb1b0_0
snappy 1.1.10 h6c2663c_1
sniffio 1.3.0 py311haa95532_0
snowflake-connector-python 3.7.0 py311hd77b12b_0
snowflake-ml-python 1.4.0 pypy_0 https://raw.githubusercontent.com/snowflakedb/snowflake-ml-python/conda/releases
snowflake-snowpark-python 1.13.0 py311haa95532_0
sortedcontainers 2.4.0 pyhd3eb1b0_0
soupsieve 2.5 py311haa95532_0
sqlalchemy 2.0.25 py311h2bbff1b_0
sqlite 3.41.2 h2bbff1b_0
sqlparse 0.4.4 py311haa95532_0
stack_data 0.2.0 pyhd3eb1b0_0
statsmodels 0.14.2 py311h0a17f05_0 conda-forge
tabulate 0.9.0 py311haa95532_0
tbb 2021.8.0 h59b6b97_0
terminado 0.17.1 py311haa95532_0
threadpoolctl 2.2.0 pyh0d69192_0
tinycss2 1.2.1 py311haa95532_0
tk 8.6.12 h2bbff1b_0
tomlkit 0.11.1 py311haa95532_0
tornado 6.3.3 py311h2bbff1b_0
traitlets 5.7.1 py311haa95532_0
typing-extensions 4.9.0 py311haa95532_1
typing_extensions 4.9.0 py311haa95532_1
tzdata 2024a h04d1e81_0
ucrt 10.0.20348.0 haa95532_0
urllib3 2.0.7 py311haa95532_0
utf8proc 2.6.1 h2bbff1b_1
vc 14.2 h21ff451_1
vc14_runtime 14.38.33130 h82b7239_18 conda-forge
vs2015_runtime 14.38.33130 hcb4865c_18 conda-forge
waitress 2.0.0 pyhd3eb1b0_0
wcwidth 0.2.5 pyhd3eb1b0_0
webencodings 0.5.1 py311haa95532_1
websocket-client 0.58.0 py311haa95532_4
werkzeug 2.3.8 py311haa95532_0
wheel 0.41.2 py311haa95532_0
widgetsnbextension 4.0.10 py311haa95532_0
win_inet_pton 1.1.0 py311haa95532_0
winpty 0.4.3 4
wrapt 1.14.1 py311h2bbff1b_0
xgboost 1.7.3 py311haa95532_0
xz 5.4.6 h8cc25b3_0
yaml 0.2.5 he774522_0
yarl 1.9.3 py311h2bbff1b_0
zeromq 4.3.5 hd77b12b_0
zipp 3.17.0 py311haa95532_0
zlib 1.2.13 hcfcfb64_5 conda-forge
zstd 1.5.5 hd43e919_0

  1. What did you do?
    from snowflake.ml.modeling.xgboost import XGBRegressor
    from snowflake.snowpark.functions import col, lit, random, sin, when
    from utils import get_session  # local helper module, not shown in this issue

    session = get_session.session()

    N = 10**5
    _ONE_MILLION = 10**6

    # Feature x_0: pseudo-random value derived from random().
    df = session.range(1, N).to_df("ind").with_column(
        "x_0", (random() % _ONE_MILLION) / _ONE_MILLION
    )

    # weights1: all rows; weights2: only ind < N/10; weights3: only ind > N/10.
    df = df.with_columns(
        ["weights1", "weights2", "weights3"],
        [
            lit(1.0),
            when(col("ind") < lit(N / 10), 1.0).otherwise(0.0),
            when(col("ind") > lit(N / 10), 1.0).otherwise(0.0),
        ],
    )

    # Target is linear in x_0 for ind < N/10 and sinusoidal for ind > N/10.
    df = df.with_column(
        "target",
        when(col("ind") < lit(N / 10), 1.0).otherwise(0.0) * col("x_0")
        + when(col("ind") > lit(N / 10), 1.0).otherwise(0.0) * sin(10 * col("x_0")),
    )

    parameters = {
        "input_cols": ["X_0"],
        "label_cols": ["TARGET"],
    }

    # Three models that differ only in the sample weight column.
    model1 = XGBRegressor(
        **parameters,
        sample_weight_col="weights1",
        output_cols=["PREDICTION1"],
    )
    model2 = XGBRegressor(
        **parameters,
        sample_weight_col="weights2",
        output_cols=["PREDICTION2"],
    )
    model3 = XGBRegressor(
        **parameters,
        sample_weight_col="weights3",
        output_cols=["PREDICTION3"],
    )

    models = [model1, model2, model3]
    for m in models:
        m.fit(df)

    # Evaluation grid with the reference sine curve alongside.
    test = session.range(-1, 1, 0.01).to_df("X_0").with_column(
        "sinus",
        sin(10 * col("X_0")),
    )

    # Each predict() call appends that model's output column.
    for m in models:
        test = m.predict(test)

    test_snow = test.toPandas()
    print(test_snow)


output:

           X_0     SINUS  PREDICTION1  PREDICTION2  PREDICTION3
    0    -1.00  0.544021     0.515664     0.515664     0.515664
    1    -0.99  0.457536     0.405519     0.405519     0.405519
    2    -0.98  0.366479     0.183660     0.183660     0.183660
    3    -0.97  0.271761     0.211220     0.211220     0.211220
    4    -0.96  0.174327     0.039056     0.039056     0.039056
    ..     ...       ...          ...          ...          ...
    195   0.95 -0.075151     0.047328     0.047328     0.047328
    196   0.96 -0.174327    -0.060364    -0.060364    -0.060364
    197   0.97 -0.271761     0.034832     0.034832     0.034832
    198   0.98 -0.366479    -0.278535    -0.278535    -0.278535
    199   0.99 -0.457536    -0.390598    -0.390598    -0.390598
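The printed frame only shows the first and last five rows, so it may help to check the equality programmatically. The sketch below does that in plain Python on the ten rows visible above; the helper name `columns_identical` is made up for illustration and is not part of any library.

```python
# Rows copied from the printed output: (X_0, PREDICTION1, PREDICTION2, PREDICTION3).
rows = [
    (-1.00, 0.515664, 0.515664, 0.515664),
    (-0.99, 0.405519, 0.405519, 0.405519),
    (-0.98, 0.183660, 0.183660, 0.183660),
    (-0.97, 0.211220, 0.211220, 0.211220),
    (-0.96, 0.039056, 0.039056, 0.039056),
    (0.95, 0.047328, 0.047328, 0.047328),
    (0.96, -0.060364, -0.060364, -0.060364),
    (0.97, 0.034832, 0.034832, 0.034832),
    (0.98, -0.278535, -0.278535, -0.278535),
    (0.99, -0.390598, -0.390598, -0.390598),
]


def columns_identical(rows, tol=0.0):
    """Return True if every prediction in a row matches the first one within tol."""
    return all(
        abs(p - preds[0]) <= tol
        for _, *preds in rows
        for p in preds
    )


print(columns_identical(rows))  # True
```

On the full pandas frame, the same question could presumably be answered with something like `test_snow["PREDICTION1"].equals(test_snow["PREDICTION2"])`.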

  1. What did you expect to see?
    I expected the three models to produce different predictions, because each one is trained with a different sample weight column (weights1, weights2, weights3). Specifically:
  • PREDICTION1 should reflect a model trained on the entire dataset with equal weights.
  • PREDICTION2 should reflect a model driven by the first 10,000 samples (weight 1.0; all other rows have weight 0.0), which follow a linear pattern.
  • PREDICTION3 should reflect a model driven by the samples beyond 10,000 (weight 1.0; all other rows have weight 0.0), which follow a sinusoidal pattern.
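To make the weighting scheme concrete, here is a small plain-Python sketch that mirrors the `when(...).otherwise(...)` expressions from the reproduction and counts how many rows each weight column selects. It assumes N = 10**5 (consistent with the "10,000 samples" mentioned above); the function name `weights` is only illustrative.

```python
N = 10**5
THRESHOLD = N / 10  # 10000.0, the boundary used in the reproduction


def weights(ind):
    """Return (weights1, weights2, weights3) for a given row index."""
    w1 = 1.0
    w2 = 1.0 if ind < THRESHOLD else 0.0
    w3 = 1.0 if ind > THRESHOLD else 0.0
    return w1, w2, w3


inds = range(1, N)  # session.range(1, N) yields 1 .. N-1
n_linear = sum(weights(i)[1] for i in inds)
n_sine = sum(weights(i)[2] for i in inds)
n_neither = sum(1 for i in inds if weights(i)[1] == 0.0 and weights(i)[2] == 0.0)

print(int(n_linear), int(n_sine), n_neither)  # 9999 89999 1
```

The two non-trivial weight columns select disjoint sets of rows (only the boundary row ind = 10,000 gets zero weight in both), so a model that honours them should see either the linear block or the sinusoidal block, never both.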

However, the Snowflake Snowpark implementation of XGBRegressor seems to ignore the sample weights, resulting in identical predictions for all models. Running a similar experiment directly with the standard xgboost library outside of Snowflake results in distinct linear and sinusoidal predictions for model2 and model3, respectively.

@sfc-gh-afero

Thank you for reporting this issue. I was able to use your example to reproduce it on my end, and we will investigate it as a bug.

@sfc-gh-afero added the bug label Jul 31, 2024