From 42b440b4df38ce823b3f01e7c5efdc727aa91b1d Mon Sep 17 00:00:00 2001 From: Lilli Freischem Date: Tue, 14 May 2024 11:17:48 +0000 Subject: [PATCH] further dev testing for normalisation function and training --- notebooks/2.0-test-normalization.ipynb | 1302 ++++++++++++++++++++---- 1 file changed, 1094 insertions(+), 208 deletions(-) diff --git a/notebooks/2.0-test-normalization.ipynb b/notebooks/2.0-test-normalization.ipynb index d703dce..cd88817 100644 --- a/notebooks/2.0-test-normalization.ipynb +++ b/notebooks/2.0-test-normalization.ipynb @@ -27,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -98,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -108,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -117,104 +117,614 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "msg_time = xr.open_dataset(msg_files[1])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1.64, 3.92, 8.7, 9.66, 10.8, 12.0, 13.4, 0.64, 0.81, 6.25, 7.35]" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(msg_time.band_wavelength.values)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "ds_goes = normalize(goes_files[:10])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "msg_files = os.listdir('/mnt/disks/data/miniset/msg/geoprocessed/')\n", + "msg_files = ['/mnt/disks/data/miniset/msg/geoprocessed/' + f for f in msg_files]" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "from iti.data.geo_editor import BandSelectionEditor\n" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(11, 1289, 891)" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_dict['data'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7\n" + ] + }, + { + "ename": "IndexError", + "evalue": "index 0 is out of bounds for axis 0 with size 0", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[80], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# Get indexes of bands to select\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# indexes = [np.where(source_bands == wvl)[0][0] for wvl in target_bands]\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m wvl \u001b[38;5;129;01min\u001b[39;00m target_bands:\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwhere\u001b[49m\u001b[43m(\u001b[49m\u001b[43msource_bands\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mwvl\u001b[49m\u001b[43m)\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m)\n", + "\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 0 with size 0" + ] + } + ], + "source": [ + "source_bands = data_dict[\"wavelengths\"]\n", + "# Get indexes of bands to select\n", + "# indexes = [np.where(source_bands == wvl)[0][0] for wvl in target_bands]\n", + "for wvl in target_bands:\n", + " print(np.where(source_bands == wvl)[0][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 1.64, 3.92, 8.7 , 9.66, 10.8 , 12. , 13.4 , 0.64, 0.81,\n", + " 6.25, 7.35]),\n", + " 3.89)" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "source_bands, wvl" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "data = msg_time.Rad.compute().to_numpy()\n", + "target_bands=[0.64, 3.92, 7.35, 9.66, 13.4]\n", + "data_dict = {\"data\": data,\n", + " \"wavelengths\" : msg_time.band_wavelength.compute().to_numpy()}\n", + "editor = BandSelectionEditor(target_bands=target_bands)\n", + "data_bands_order1=editor.call(data_dict=data_dict)\n", + "\n", + "\n", + "\n", + "target_bands=[7.35, 9.66, 13.4, 0.64, 3.92]\n", + "data_dict = {\"data\": data,\n", + " \"wavelengths\" : msg_time.band_wavelength.compute().to_numpy()}\n", + "editor = BandSelectionEditor(target_bands=target_bands)\n", + "data_bands_order2=editor.call(data_dict=data_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1715683597.6426992" + ] + }, + "execution_count": 93, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import time\n", + "time.time()" + ] + }, + { + "cell_type": "code", + "execution_count": 92, "metadata": {}, "outputs": [ + { + "data": { + "text/plain": [ + "(array([ 0.64, 3.92, 7.35, 9.66, 13.4 ]),\n", + " array([ 7.35, 9.66, 13.4 , 0.64, 3.92]))" + ] + }, + "execution_count": 92, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_bands_order1['wavelengths'], data_bands_order2['wavelengths']" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "logging_config = {}\n", + "wandb_id = logging_config['wandb_id'] if 'wandb_id' in logging_config else None\n", + "log_model = logging_config['wandb_log_model'] if 'wandb_log_model' in logging_config else False\n" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlilli-freischem\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Path /home/freischem/outputs/miniset/wandb/ wasn't writable, using system temp directory.\n", + "wandb: WARNING Path /home/freischem/outputs/miniset/wandb/ wasn't writable, using system temp directory\n" + ] + }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataArray 'Rad' ()> Size: 4B\n",
-       "array(20.26424, dtype=float32)\n",
+       "
<xarray.Dataset> Size: 272B\n",
+       "Dimensions:          (band_wavelength: 16, band: 16)\n",
        "Coordinates:\n",
-       "    band     int8 1B 16
" + " * band_wavelength (band_wavelength) float32 64B 0.47 0.64 ... 12.27 13.27\n", + " * band (band) int8 16B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16\n", + "Data variables:\n", + " mean (band) float32 64B 144.0 74.45 43.39 ... 89.34 96.57 85.23\n", + " std (band) float64 128B 97.21 80.72 53.49 ... 23.48 23.76 17.25" ], "text/plain": [ - " Size: 4B\n", - "array(20.26424, dtype=float32)\n", + " Size: 272B\n", + "Dimensions: (band_wavelength: 16, band: 16)\n", "Coordinates:\n", - " band int8 1B 16" + " * band_wavelength (band_wavelength) float32 64B 0.47 0.64 ... 12.27 13.27\n", + " * band (band) int8 16B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16\n", + "Data variables:\n", + " mean (band) float32 64B 144.0 74.45 43.39 ... 89.34 96.57 85.23\n", + " std (band) float64 128B 97.21 80.72 53.49 ... 23.48 23.76 17.25" ] }, - "execution_count": 20, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# goes_time.Rad[15].std()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "ds_goes = normalize(goes_files[:10])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "msg_files = os.listdir('/mnt/disks/data/miniset/msg/geoprocessed/')\n", - "msg_files = ['/mnt/disks/data/miniset/msg/geoprocessed/' + f for f in msg_files]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "ds_msg = normalize(msg_files[:10])" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "ds_goes = ds_goes.compute()" + "ds_goes" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "ds_msg = ds_msg.compute()" + "from iti.data.editor import Editor\n", + "\n", + "import numpy as np\n", + "\n", + "class MeanStdNormEditor(Editor):\n", + " \"\"\"\n", + " Normalise each band in the data using the mean and std from the norm_ds.\n", + " \"\"\"\n", + " def __init__(self, norm_ds, key=\"data\"):\n", + " \"\"\"\n", + " Args:\n", + " norm_ds (xarray.Dataset): Dataset with normalization values (mean and std)\n", + " key (str): Key in dictionary to apply transformation\n", + " \"\"\"\n", + " self.key = key\n", + " self.norm = norm_ds\n", + "\n", + " def call(self, data_dict, **kwargs):\n", + " data = data_dict[self.key]\n", + " # use wavelengths and only normalise the bands that we have in the data\n", + " data_wavelengths = data_dict[\"wavelengths\"]\n", + " # Get indeces of bands to select\n", + " indeces = [np.where(self.norm.band_wavelength == wvl)[0][0] for wvl in data_wavelengths]\n", + " \n", + " # extract relevant means and stds\n", + " means = self.norm['mean'][indeces].values\n", + " stds = self.norm['std'][indeces].values\n", + "\n", + " # check that number of channels equals number of means & stds\n", + " assert data.shape[0] == means.shape[0]\n", + " assert data.shape[0] == stds.shape[0]\n", + "\n", + " # apply normalization\n", + " data = (data - means[:, None, None]) / stds[:, None, None]\n", + " \n", + " # Update dictionary\n", + " data_dict[self.key] = data\n", + " return data_dict" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 62, "metadata": {}, "outputs": [ { @@ -924,111 +1456,465 @@ " stroke: currentColor;\n", " fill: currentColor;\n", "}\n", - "
<xarray.Dataset> Size: 272B\n",
-       "Dimensions:          (band_wavelength: 16, band: 16)\n",
+       "
<xarray.Dataset> Size: 275MB\n",
+       "Dimensions:          (x: 504, y: 3687, time: 1, band_wavelength: 16, band: 16)\n",
        "Coordinates:\n",
-       "  * band_wavelength  (band_wavelength) float32 64B 0.47 0.64 ... 12.27 13.27\n",
-       "  * band             (band) int8 16B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16\n",
-       "Data variables:\n",
-       "    mean             (band) float32 64B 144.0 74.45 43.39 ... 89.34 96.57 85.23\n",
-       "    std              (band) float64 128B 97.21 80.72 53.49 ... 23.48 23.76 17.25
  • band
    PandasIndex
    PandasIndex(Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], dtype='int8', name='band'))
  • naming_authority :
    gov.nesdis.noaa
    Conventions :
    CF-1.7
    standard_name_vocabulary :
    CF Standard Name Table (v35, 20 July 2016)
    institution :
    DOC/NOAA/NESDIS > U.S. Department of Commerce, National Oceanic and Atmospheric Administration, National Environmental Satellite, Data, and Information Services
    project :
    GOES
    production_site :
    WCDAS
    production_environment :
    OE
    spatial_resolution :
    1km at nadir
    Metadata_Conventions :
    Unidata Dataset Discovery v1.0
    orbital_slot :
    GOES-East
    platform_ID :
    G16
    instrument_type :
    GOES R Series Advanced Baseline Imager
    scene_id :
    Full Disk
    instrument_ID :
    FM1
    title :
    ABI L1b Radiances
    summary :
    Single reflective band ABI L1b Radiance Products are digital maps of outgoing radiance values at the top of the atmosphere for visible and near-IR bands.
    keywords :
    SPECTRAL/ENGINEERING > VISIBLE WAVELENGTHS > VISIBLE RADIANCE
    keywords_vocabulary :
    NASA Global Change Master Directory (GCMD) Earth Science Keywords, Version 7.0.0.0.0
    iso_series_metadata_id :
    a70be540-c38b-11e0-962b-0800200c9a66
    license :
    Unclassified data. Access is restricted to approved users only.
    processing_level :
    National Aeronautics and Space Administration (NASA) L1b
    cdm_data_type :
    Image
    dataset_name :
    OR_ABI-L1b-RadF-M6C01_G16_s20203001400197_e20203001409505_c20203001409565.nc
    production_data_source :
    Realtime
    timeline_id :
    ABI Mode 6
    date_created :
    2020-10-26T14:09:56.5Z
    time_coverage_start :
    2020-10-26T14:00:19.7Z
    time_coverage_end :
    2020-10-26T14:09:50.5Z
    LUT_Filenames :
    SpaceLookParams(FM1A_CDRL79RevP_PR_09_00_02)-637827000.0.h5 QTableBand01(FM1A_CDRL79RevH_DO_07_00_00)-582860861.0.h5 CalTargetTimeIntervals(FM1A_CDRL79RevP_DO_08_00_01)-611906620.0.h5 BandSaturationLimits(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.h5 SolarSpaceLookParams(FM1A_CDRL79RevH_DO_09_00_00)-600765435.0.h5 DeadRowListParams(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.h5 Mirror_Record(FM1A_CDRL79RevG_DO_07_00_00)-582860861.0.h5 KalmanAstroConsts(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.xml KalmanFilterControls(FM1A_CDRL79RevJ_PR_09_02_06)-652953000.0.xml KalmanMeasMaxSensibles(FMAA_INT_ONLY_DO_09_01_00)-652936814.0.xml KalmanPreprocessorControls(FM1A_CDRL79RevJ_DO_07_00_00)-582860861.0.xml KalmanReferenceData(FM1A_CDRL79RevH_DO_08_00_00)-888.0.xml KalmanStarCatalogs(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.xml ABI_NavigationRDP_Band01(FM1A_CDRL79RevJ_DO_07_00_00)-582860861.0.xml ABI_NavigationParameters_Band01(FM1A_CDRL79RevH_DO_07_00_00)-582860861.0.xml ABI_ResamplingImplementation_Band01(FM1A_CDRL79RevH_DO_07_02_00)-602129336.0.xml ABI_ResamplingParameters_Band01(FM1A_CDRL79RevJ_DO_07_00_00)-582860861.0.xml StarLookParams(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.h5 StarDetectionParams(FM1A_CDRL79RevJ_DO_07_00_00)-582860861.0.xml ResamplingScaledConversion(FMAA_INT_ONLY_DO_08_00_00)-1111.0.xml BlockReleaseRegions(FMAA_INT_ONLY_DO_08_00_00)-2222.0.csv VNIR_RetrievalParameters(FM1A_CDRL79RevH_DO_08_00_00)-600000000.0.h5 SCT_Record(FM1A_CDRL79RevM_DO_09_00_00)-600765435.0.h5 ICM_ConversionConsts(FM1A_CDRL43-18_DO_09_01_00)-652936750.0.h5 ICM_SensorCoefficients(FM1A_TMABI_18_159_TMABI_18_533_DO_09_01_00)-652936750.0.h5
    id :
    fc5ae74c-413e-4ee6-8b54-bfd6d004f270
  • " ], "text/plain": [ - " Size: 272B\n", - "Dimensions: (band_wavelength: 16, band: 16)\n", + " Size: 275MB\n", + "Dimensions: (x: 504, y: 3687, time: 1, band_wavelength: 16, band: 16)\n", "Coordinates:\n", + " * x (x) float32 2kB 3.108e+06 3.109e+06 ... 3.611e+06 3.612e+06\n", + " * y (y) float32 15kB 501.3 1.503e+03 ... 3.693e+06 3.694e+06\n", + " * time (time) U.S. Department of Commerce,...\n", + " project: GOES\n", + " production_site: WCDAS\n", + " ... ...\n", + " timeline_id: ABI Mode 6\n", + " date_created: 2020-10-26T14:09:56.5Z\n", + " time_coverage_start: 2020-10-26T14:00:19.7Z\n", + " time_coverage_end: 2020-10-26T14:09:50.5Z\n", + " LUT_Filenames: SpaceLookParams(FM1A_CDRL79RevP_PR_09_00_02)-6...\n", + " id: fc5ae74c-413e-4ee6-8b54-bfd6d004f270" ] }, - "execution_count": 12, + "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ds_goes" + "goes_time" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "from iti.data.editor import Editor\n", - "\n", "import numpy as np\n", - "\n", - "class NormalizeEditor(Editor):\n", - " \"\"\"\n", - " Convert radiance values from mW/m^2/sr/cm^-1 to W/m^2/sr/um\n", - " \"\"\"\n", - " def __init__(self, norm_ds, key=\"data\"):\n", - " \"\"\"\n", - " Args:\n", - " norm_ds (xarray.Dataset): Dataset with normalization values\n", - " key (str): Key in dictionary to apply transformation\n", - " \"\"\"\n", - " self.key = key\n", - " self.norm = norm_ds\n", - "\n", - " def call(self, data_dict, **kwargs):\n", - " data = data_dict[self.key]\n", - " # use wavelengths and only normalise the bands that we have in the data\n", - " data_wavelengths = data_dict[\"wavelengths\"]\n", - " # Get indeces of bands to select\n", - " indeces = [np.where(self.norm.band_wavelength == wvl)[0][0] for wvl in data_wavelengths]\n", - " \n", - " # extract relevant means and stds\n", - " means = ds_goes['mean'][indeces].values\n", - " stds = ds_goes['std'][indeces].values\n", - "\n", - " # TODO apply normalization\n", - " data = \n", - "\n", - " # Convert units\n", - " # data = convert_units(data, wavelengths)\n", - " # Update dictionary\n", - " data_dict[self.key] = data\n", - " return data_dict" + "data_wavelengths = np.array([0.47, 1.38, 1.61,])\n", + "indeces = [np.where(ds_goes.band_wavelength == wvl)[0][0] for wvl in data_wavelengths]" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "data_wavelengths = np.array([0.47, 1.38, 1.61,])\n", - "indeces = [np.where(ds_goes.band_wavelength == wvl)[0][0] for wvl in data_wavelengths]" + "means = ds_goes['mean'][indeces].values\n", + "stds = ds_goes['std'][indeces].values" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "means = ds_goes['mean'][indeces].values\n", - "stds = ds_goes['std'][indeces].values" + "data = goes_time.Rad[indeces].compute().to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 61, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_norm = (data - means[:, None, None]) / stds[:, None, None]\n", + "\n", + "np.isclose(((data[2] - means[2]) / stds[2]), data_norm[2]).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.38646985, -0.53684598, -0.428241 , ..., -0.71228487,\n", + " -0.7039306 , -0.67886792],\n", + " [-0.31963594, -0.55355446, -0.64545097, ..., -0.6955764 ,\n", + " -0.6955764 , -0.68722212],\n", + " [-0.63709677, -0.59532561, -0.67051372, ..., -0.68722212,\n", + " -0.6955764 , -0.68722212],\n", + " ...,\n", + " [-0.4198868 , -0.40317825, -0.23609364, ..., 0.47401596,\n", + " 0.42389061, 0.40718205],\n", + " [-0.25280219, -0.12748873, 0.14820094, ..., 0.39047366,\n", + " 0.37376526, 0.43224481],\n", + " [-0.25280219, 0.08972139, 0.23174324, ..., 0.32363975,\n", + " 0.39882785, 0.38211946]])" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "((data - means[:, None, None]) / stds[:, None, None])[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.38646984, -0.536846 , -0.42824098, ..., -0.71228486,\n", + " -0.7039306 , -0.67886794],\n", + " [-0.31963593, -0.5535545 , -0.64545095, ..., -0.69557637,\n", + " -0.69557637, -0.6872221 ],\n", + " [-0.63709676, -0.5953256 , -0.6705137 , ..., -0.6872221 ,\n", + " -0.69557637, -0.6872221 ],\n", + " ...,\n", + " [-0.4198868 , -0.40317824, -0.23609364, ..., 0.47401595,\n", + " 0.4238906 , 0.40718204],\n", + " [-0.2528022 , -0.12748873, 0.14820093, ..., 0.39047366,\n", + " 0.37376526, 0.4322448 ],\n", + " [-0.2528022 , 0.08972139, 0.23174325, ..., 0.32363975,\n", + " 0.39882785, 0.38211945]], dtype=float32)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(data[0] - means[0]) / stds[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((3, 3687, 504), (3,))" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(data - means[:, None, None]).shape, stds.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(True, True, True)" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "((data[0] - means[0]) == (data - means[:, None, None])[0]).all(), ((data[2] - means[2]) == (data - means[:, None, None])[2]).all(), ((data[1] - means[1]) == (data - means[:, None, None])[1])[1:].all()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[-0.38646985, -0.53684598, -0.428241 , ..., -0.71228487,\n", + " -0.7039306 , -0.67886792],\n", + " [-0.31963594, -0.55355446, -0.64545097, ..., -0.6955764 ,\n", + " -0.6955764 , -0.68722212],\n", + " [-0.63709677, -0.59532561, -0.67051372, ..., -0.68722212,\n", + " -0.6955764 , -0.68722212],\n", + " ...,\n", + " [-0.4198868 , -0.40317825, -0.23609364, ..., 0.47401596,\n", + " 0.42389061, 0.40718205],\n", + " [-0.25280219, -0.12748873, 0.14820094, ..., 0.39047366,\n", + " 0.37376526, 0.43224481],\n", + " [-0.25280219, 0.08972139, 0.23174324, ..., 0.32363975,\n", + " 0.39882785, 0.38211946]],\n", + "\n", + " [[ nan, nan, nan, ..., nan,\n", + " nan, nan],\n", + " [-0.41461605, -0.41461605, -0.41461605, ..., -0.41461605,\n", + " -0.31526359, -0.31526359],\n", + " [-0.41461605, -0.41461605, -0.41461605, ..., -0.31526359,\n", + " -0.31526359, -0.31526359],\n", + " ...,\n", + " [-0.41461605, -0.41461605, -0.41461605, ..., -0.31526359,\n", + " -0.31526359, -0.31526359],\n", + " [-0.41461605, -0.41461605, -0.41461605, ..., -0.31526359,\n", + " -0.31526359, -0.31526359],\n", + " [-0.41461605, -0.41461605, -0.41461605, ..., -0.31526359,\n", + " -0.31526359, -0.31526359]],\n", + "\n", + " [[ 0.15756134, -0.20650614, 0.15756134, ..., -0.69192993,\n", + " -0.69192993, -0.69192993],\n", + " [ 0.27891736, -0.20650614, -0.44921818, ..., -0.69192993,\n", + " -0.69192993, -0.69192993],\n", + " [-0.32786216, -0.20650614, -0.44921818, ..., -0.69192993,\n", + " -0.69192993, -0.69192993],\n", + " ...,\n", + " [-0.69192993, -0.69192993, -0.08515041, ..., 1.73518843,\n", + " 1.61383242, 1.61383242],\n", + " [-0.20650614, 0.03620532, 0.88569688, ..., 1.4924764 ,\n", + " 1.4924764 , 1.61383242],\n", + " [-0.08515041, 1.0070529 , 1.37112038, ..., 1.4924764 ,\n", + " 1.61383242, 1.4924764 ]]])" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(data - means[:, None, None]) / stds[:, None, None]" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(True, True, False)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "((data[0] - means[0]) == (data - means[:, None, None])[0]).all(), ((data[2] - means[2]) == (data - means[:, None, None])[2]).all(), ((data[1] - means[1]) == (data - means[:, None, None])[1]).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "((data[1] - means[1])[1:] == (data - means[:, None, None])[1][1:]).all()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ nan, nan, nan, ..., nan, nan,\n", + " nan],\n", + " [-3.389072 , -3.389072 , -3.389072 , ..., -3.389072 , -2.5769649,\n", + " -2.5769649],\n", + " [-3.389072 , -3.389072 , -3.389072 , ..., -2.5769649, -2.5769649,\n", + " -2.5769649],\n", + " ...,\n", + " [-3.389072 , -3.389072 , -3.389072 , ..., -2.5769649, -2.5769649,\n", + " -2.5769649],\n", + " [-3.389072 , -3.389072 , -3.389072 , ..., -2.5769649, -2.5769649,\n", + " -2.5769649],\n", + " [-3.389072 , -3.389072 , -3.389072 , ..., -2.5769649, -2.5769649,\n", + " -2.5769649]], dtype=float32)" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(data[1] - means[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[-37.568344 , -52.186256 , -41.628876 , ..., -69.24049 ,\n", + " -68.42838 , -65.992065 ],\n", + " [-31.071487 , -53.81047 , -62.743637 , ..., -67.61628 ,\n", + " -67.61628 , -66.80417 ],\n", + " [-61.931534 , -57.871002 , -65.17996 , ..., -66.80417 ,\n", + " -67.61628 , -66.80417 ],\n", + " ...,\n", + " [-40.816772 , -39.19255 , -22.950424 , ..., 46.078613 ,\n", + " 41.20598 , 39.581757 ],\n", + " [-24.574646 , -12.393051 , 14.406464 , ..., 37.95755 ,\n", + " 36.333344 , 42.01808 ],\n", + " [-24.574646 , 8.721725 , 22.527527 , ..., 31.460693 ,\n", + " 38.769653 , 37.145447 ]],\n", + "\n", + " [[ nan, nan, nan, ..., nan,\n", + " nan, nan],\n", + " [ -3.389072 , -3.389072 , -3.389072 , ..., -3.389072 ,\n", + " -2.5769649 , -2.5769649 ],\n", + " [ -3.389072 , -3.389072 , -3.389072 , ..., -2.5769649 ,\n", + " -2.5769649 , -2.5769649 ],\n", + " ...,\n", + " [ -3.389072 , -3.389072 , -3.389072 , ..., -2.5769649 ,\n", + " -2.5769649 , -2.5769649 ],\n", + " [ -3.389072 , -3.389072 , -3.389072 , ..., -2.5769649 ,\n", + " -2.5769649 , -2.5769649 ],\n", + " [ -3.389072 , -3.389072 , -3.389072 , ..., -2.5769649 ,\n", + " -2.5769649 , -2.5769649 ]],\n", + "\n", + " [[ 1.0543909 , -1.3819265 , 1.0543909 , ..., -4.630353 ,\n", + " -4.630353 , -4.630353 ],\n", + " [ 1.866498 , -1.3819265 , -3.0061407 , ..., -4.630353 ,\n", + " -4.630353 , -4.630353 ],\n", + " [ -2.1940336 , -1.3819265 , -3.0061407 , ..., -4.630353 ,\n", + " -4.630353 , -4.630353 ],\n", + " ...,\n", + " [ -4.630353 , -4.630353 , -0.56982136, ..., 11.611775 ,\n", + " 10.799668 , 10.799668 ],\n", + " [ -1.3819265 , 0.24228382, 5.9270296 , ..., 9.987561 ,\n", + " 9.987561 , 10.799668 ],\n", + " [ -0.56982136, 6.7391367 , 9.175454 , ..., 9.987561 ,\n", + " 10.799668 , 9.987561 ]]], dtype=float32)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data - means[:, None, None]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape[0] == means.shape[0]" ] }, {