From e437f3cd7771b8863f76c7129e92d4ae26e424ef Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Wed, 15 Jan 2025 13:11:21 +0100 Subject: [PATCH 01/18] fix sklearn warnings --- mambular/preprocessing/basis_expansion.py | 5 ++++- mambular/preprocessing/ple_encoding.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mambular/preprocessing/basis_expansion.py b/mambular/preprocessing/basis_expansion.py index d98e9e3..2ac9823 100644 --- a/mambular/preprocessing/basis_expansion.py +++ b/mambular/preprocessing/basis_expansion.py @@ -43,7 +43,6 @@ def __init__( if spline_implementation not in ["scipy", "sklearn"]: raise ValueError("Invalid spline implementation. Choose 'scipy' or 'sklearn'.") - self.fitted = False @staticmethod def knot_identification_using_decision_tree(X, y, task="regression", n_knots=5): @@ -76,6 +75,7 @@ def fit(self, X, y=None): raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.") self.knots = [] + self.n_features_in_ = X.shape[1] if self.use_decision_tree and self.spline_implementation == "scipy": self.knots = self.knot_identification_using_decision_tree(X, y, self.task, self.n_knots) @@ -105,6 +105,8 @@ def fit(self, X, y=None): self.transformer.fit(X) self.fitted = True + + elif self.spline_implementation == "sklearn" and not self.use_decision_tree: if self.strategy == "quantile": # print("Using sklearn spline transformer using quantile") @@ -124,6 +126,7 @@ def fit(self, X, y=None): self.fitted = True self.transformer.fit(X) + return self def transform(self, X): diff --git a/mambular/preprocessing/ple_encoding.py b/mambular/preprocessing/ple_encoding.py index a75a47f..01c217a 100644 --- a/mambular/preprocessing/ple_encoding.py +++ b/mambular/preprocessing/ple_encoding.py @@ -74,6 +74,7 @@ def __init__(self, n_bins=20, tree_params={}, task="regression", conditions=None self.pattern = r"-?\d+\.?\d*[eE]?[+-]?\d*" def fit(self, feature, target): + self.n_features_in_ = 1 if self.task == "regression": dt = DecisionTreeRegressor(max_leaf_nodes=self.n_bins) elif self.task == "classification": @@ -84,9 +85,11 @@ def fit(self, feature, target): dt.fit(feature, target) self.conditions = tree_to_code(dt, ["feature"]) + #self.fitted = True return self def transform(self, feature): + if feature.shape == (feature.shape[0], 1): feature = np.squeeze(feature, axis=1) else: @@ -134,6 +137,8 @@ def transform(self, feature): else: return np.array(ple_encoded_feature, dtype=np.float32) + + def get_feature_names_out(self, input_features=None): if input_features is None: raise ValueError("input_features must be specified") From d4c61f393b12d95d562696374a7b43d75e21a92a Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 14:59:25 +0100 Subject: [PATCH 02/18] include predict step --- mambular/base_models/lightning_wrapper.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index 1e8dd34..c786f20 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -302,6 +302,28 @@ def test_step(self, batch, batch_idx): # type: ignore return test_loss + def predict_step(self, batch, batch_idx): + """Predict step for a single batch. + + Parameters + ---------- + batch : tuple + Batch of data containing numerical features, categorical features, and labels. + batch_idx : int + Index of the batch. + + Returns + ------- + Tensor + Predictions. + """ + + cat_features, num_features, labels = batch + preds = self(num_features=num_features, cat_features=cat_features) + + return preds + + def on_validation_epoch_end(self): """Callback executed at the end of each validation epoch. From e3e39bf18727298a4760f4a4fe6e8e179fa4e117 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 15:34:02 +0100 Subject: [PATCH 03/18] adjust datamodule and dataset to include prediction dataset --- mambular/base_models/lightning_wrapper.py | 2 +- mambular/data_utils/datamodule.py | 69 ++++++++++++----------- mambular/data_utils/dataset.py | 53 ++++++++++------- 3 files changed, 69 insertions(+), 55 deletions(-) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index c786f20..097a5b2 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -318,7 +318,7 @@ def predict_step(self, batch, batch_idx): Predictions. """ - cat_features, num_features, labels = batch + cat_features, num_features = batch preds = self(num_features=num_features, cat_features=cat_features) return preds diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py index f78d6d8..d45452b 100644 --- a/mambular/data_utils/datamodule.py +++ b/mambular/data_utils/datamodule.py @@ -188,23 +188,11 @@ def setup(self, stage: str): regression=self.regression, ) self.val_dataset = MambularDataset(val_cat_tensors, val_num_tensors, val_labels, regression=self.regression) - elif stage == "test": - if not self.test_preprocessor_fitted: - raise ValueError( - "The preprocessor has not been fitted. Please fit the preprocessor before transforming the test data." - ) - - self.test_dataset = MambularDataset( - self.test_cat_tensors, - self.test_num_tensors, - train_labels, # type: ignore - regression=self.regression, - ) - def preprocess_test_data(self, X): - self.test_cat_tensors = [] - self.test_num_tensors = [] - test_preprocessed_data = self.preprocessor.transform(X) + def preprocess_new_data(self, X): + cat_tensors = [] + num_tensors = [] + preprocessed_data = self.preprocessor.transform(X) # Populate tensors for categorical features, if present in processed data for key in self.cat_feature_info: # type: ignore @@ -215,21 +203,21 @@ def preprocess_test_data(self, X): else torch.long ) cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' - if cat_key in test_preprocessed_data: - self.test_cat_tensors.append(torch.tensor(test_preprocessed_data[cat_key], dtype=dtype)) + if cat_key in preprocessed_data: + cat_tensors.append(torch.tensor(preprocessed_data[cat_key], dtype=dtype)) binned_key = "num_" + key # for binned features - if binned_key in test_preprocessed_data: - self.test_cat_tensors.append(torch.tensor(test_preprocessed_data[binned_key], dtype=dtype)) + if binned_key in preprocessed_data: + cat_tensors.append(torch.tensor(preprocessed_data[binned_key], dtype=dtype)) # Populate tensors for numerical features, if present in processed data for key in self.num_feature_info: # type: ignore num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_' - if num_key in test_preprocessed_data: - self.test_num_tensors.append(torch.tensor(test_preprocessed_data[num_key], dtype=torch.float32)) + if num_key in preprocessed_data: + num_tensors.append(torch.tensor(preprocessed_data[num_key], dtype=torch.float32)) - self.test_preprocessor_fitted = True - return self.test_cat_tensors, self.test_num_tensors + + return MambularDataset(cat_tensors, num_tensors, labels=None, regression=self.regression) def train_dataloader(self): """Returns the training dataloader. @@ -237,13 +225,15 @@ def train_dataloader(self): Returns: DataLoader: DataLoader instance for the training dataset. """ - - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - **self.dataloader_kwargs, - ) + if hasattr(self, "train_dataset"): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + **self.dataloader_kwargs, + ) + else: + raise ValueError("No training dataset provided!") def val_dataloader(self): """Returns the validation dataloader. @@ -251,7 +241,10 @@ def val_dataloader(self): Returns: DataLoader: DataLoader instance for the validation dataset. """ - return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + if hasattr(self, "val_dataset"): + return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + else: + raise ValueError("No validation dataset provided!") def test_dataloader(self): """Returns the test dataloader. @@ -259,4 +252,14 @@ def test_dataloader(self): Returns: DataLoader: DataLoader instance for the test dataset. """ - return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + if hasattr(self, "test_dataset"): + return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + else: + raise ValueError("No test dataset provided!") + + def predict_dataloader(self): + if hasattr(self, "predict_dataset"): + return DataLoader(self.predict_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + else: + raise ValueError("No predict dataset provided!") + diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index b1c07bf..034a581 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -3,6 +3,11 @@ from torch.utils.data import Dataset +import numpy as np +import torch +from torch.utils.data import Dataset + + class MambularDataset(Dataset): """Custom dataset for handling structured data with separate categorical and numerical features, tailored for both regression and classification tasks. @@ -11,28 +16,31 @@ class MambularDataset(Dataset): ---------- cat_features_list (list of Tensors): A list of tensors representing the categorical features. num_features_list (list of Tensors): A list of tensors representing the numerical features. - labels (Tensor): A tensor of labels. + labels (Tensor, optional): A tensor of labels. If None, the dataset is used for prediction. regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True. """ - def __init__(self, cat_features_list, num_features_list, labels, regression=True): + def __init__(self, cat_features_list, num_features_list, labels=None, regression=True): self.cat_features_list = cat_features_list # Categorical features tensors self.num_features_list = num_features_list # Numerical features tensors - self.regression = regression - if not self.regression: - self.num_classes = len(np.unique(labels)) - if self.num_classes > 2: - self.labels = labels.view(-1) + + if labels is not None: + if not self.regression: + self.num_classes = len(np.unique(labels)) + if self.num_classes > 2: + self.labels = labels.view(-1) + else: + self.num_classes = 1 + self.labels = labels else: - self.num_classes = 1 self.labels = labels + self.num_classes = 1 else: - self.labels = labels - self.num_classes = 1 + self.labels = None # No labels in prediction mode def __len__(self): - return len(self.labels) + return len(self.num_features_list[0]) # Use numerical features length def __getitem__(self, idx): """Retrieves the features and label for a given index. @@ -43,21 +51,24 @@ def __getitem__(self, idx): Returns ------- - tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical - features) and a single label (float if regression is True). + tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical features) + and a single label (if available). """ cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list] num_features = [ torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32) for feature_tensor in self.num_features_list ] - label = self.labels[idx] - if self.regression: - label = label.clone().detach().to(torch.float32) - elif self.num_classes == 1: - label = label.clone().detach().to(torch.float32) + + if self.labels is not None: + label = self.labels[idx] + if self.regression: + label = label.clone().detach().to(torch.float32) + elif self.num_classes == 1: + label = label.clone().detach().to(torch.float32) + else: + label = label.clone().detach().to(torch.long) + return cat_features, num_features, label else: - label = label.clone().detach().to(torch.long) + return cat_features, num_features # No label in prediction mode - # Keep categorical and numerical features separate - return cat_features, num_features, label From 8cc1e798fb871e938e4b5a5e0c2b9e8957b97d3f Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 15:34:15 +0100 Subject: [PATCH 04/18] fix batch prediction in sklearn models --- mambular/models/sklearn_base_classifier.py | 161 +++++++++------------ mambular/models/sklearn_base_lss.py | 41 ++---- mambular/models/sklearn_base_regressor.py | 30 ++-- 3 files changed, 89 insertions(+), 143 deletions(-) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index 85b2968..8d13dd4 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -376,126 +376,99 @@ def fit( return self def predict(self, X, device=None): - """Predicts target values for the given input samples. - + """Predicts target labels for the given input samples. + Parameters ---------- X : DataFrame or array-like, shape (n_samples, n_features) The input samples for which to predict target values. - - + Returns ------- - predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) - The predicted target values. + predictions : ndarray, shape (n_samples,) + The predicted class labels. """ # Ensure model and data module are initialized if self.task_model is None or self.data_module is None: raise ValueError("The model or data module has not been fitted yet.") - + # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - if device is None: - device = next(self.task_model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) - + self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) + # Set model to evaluation mode self.task_model.eval() - - # Perform inference - with torch.no_grad(): - logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors) - - # Check if ensemble is used - if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble - # Average logits across the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) - logits = logits.mean(dim=1) - if logits.dim() == 1: # Check if logits has only one dimension (shape (N,)) - logits = logits.unsqueeze(1) - - # Check the shape of the logits to determine binary or multi-class classification - if logits.shape[1] == 1: - # Binary classification - probabilities = torch.sigmoid(logits) - predictions = (probabilities > 0.5).long().squeeze() - else: - # Multi-class classification - probabilities = torch.softmax(logits, dim=1) - predictions = torch.argmax(probabilities, dim=1) - + + # Perform inference using PyTorch Lightning's predict function + logits_list = self.trainer.predict(self.task_model, self.data_module) + + # Concatenate predictions from all batches + logits = torch.cat(logits_list, dim=0) + + # Check if ensemble is used + if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + logits = logits.mean(dim=1) # Average over ensemble dimension + if logits.dim() == 1: # Ensure correct shape + logits = logits.unsqueeze(1) + + # Check the shape of the logits to determine binary or multi-class classification + if logits.shape[1] == 1: + # Binary classification + probabilities = torch.sigmoid(logits) + predictions = (probabilities > 0.5).long().squeeze() + else: + # Multi-class classification + probabilities = torch.softmax(logits, dim=1) + predictions = torch.argmax(probabilities, dim=1) + # Convert predictions to NumPy array and return return predictions.cpu().numpy() - + + def predict_proba(self, X, device=None): - """Predict class probabilities for the given input samples. - + """Predicts class probabilities for the given input samples. + Parameters ---------- - X : array-like or pd.DataFrame of shape (n_samples, n_features) + X : DataFrame or array-like, shape (n_samples, n_features) The input samples for which to predict class probabilities. - - - Notes - ----- - The method preprocesses the input data using the same preprocessor used during training, - sets the model to evaluation mode, and then performs inference to predict the class probabilities. - Softmax is applied to the logits to obtain probabilities, which are then converted from a PyTorch tensor - to a NumPy array before being returned. - + Returns ------- - probabilities : ndarray of shape (n_samples, n_classes) - Predicted class probabilities for each input sample. + probabilities : ndarray, shape (n_samples, n_classes) + The predicted class probabilities. """ - # Preprocess the data - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - device = next(self.task_model.parameters()).device # type: ignore - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] + # Ensure model and data module are initialized + if self.task_model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + + # Preprocess the data using the data module + self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) + + # Set model to evaluation mode + self.task_model.eval() + + # Perform inference using PyTorch Lightning's predict function + logits_list = self.trainer.predict(self.task_model, self.data_module) + + # Concatenate predictions from all batches + logits = torch.cat(logits_list, dim=0) + + # Check if ensemble is used + if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + logits = logits.mean(dim=1) # Average over ensemble dimension + if logits.dim() == 1: # Ensure correct shape + logits = logits.unsqueeze(1) + + # Compute probabilities + if logits.shape[1] > 1: + probabilities = torch.softmax(logits, dim=1) # Multi-class classification else: - num_tensors = num_tensors.to(device) - - # Set the model to evaluation mode - self.task_model.eval() # type: ignore - - # Perform inference - with torch.no_grad(): - logits = self.task_model( # type: ignore - num_features=num_tensors, cat_features=cat_tensors - ) - # Check if ensemble is used - # If using ensemble - if hasattr(self.task_model.base_model, "returns_ensemble"): # type: ignore - # Average logits across the ensemble dimension - # (assuming shape: (batch_size, ensemble_size, output_dim)) - logits = logits.mean(dim=1) - if logits.dim() == 1: # Check if logits has only one dimension (shape (N,)) - logits = logits.unsqueeze(1) - if logits.shape[1] > 1: - probabilities = torch.softmax(logits, dim=1) - else: - probabilities = torch.sigmoid(logits) - + probabilities = torch.sigmoid(logits) # Binary classification + # Convert probabilities to NumPy array and return return probabilities.cpu().numpy() + def evaluate(self, X, y_true, metrics=None): """Evaluate the model on the given data using specified metrics. diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 3a25236..1832086 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -421,7 +421,7 @@ def fit( return self - def predict(self, X, raw=False, device=None): + def predict(self, X, device=None): """Predicts target values for the given input samples. Parameters @@ -440,38 +440,23 @@ def predict(self, X, raw=False, device=None): raise ValueError("The model or data module has not been fitted yet.") # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - if device is not None: - device = next(self.task_model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) + self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) # Set model to evaluation mode self.task_model.eval() - # Perform inference - with torch.no_grad(): - predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors) - + # Perform inference using PyTorch Lightning's predict function + predictions_list = self.trainer.predict(self.task_model, self.data_module) + + # Concatenate predictions from all batches + predictions = torch.cat(predictions_list, dim=0) + # Check if ensemble is used - if getattr(self.base_model, "returns_ensemble", False): # If using ensemble - # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) - predictions = predictions.mean(dim=1) - - if not raw: - result = self.task_model.family(predictions).cpu().numpy() # type: ignore - return result - else: - return predictions.cpu().numpy() + if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + predictions = predictions.mean(dim=1) # Average over ensemble dimension + + # Convert predictions to NumPy array and return + return predictions.cpu().numpy() def evaluate(self, X, y_true, metrics=None, distribution_family=None): """Evaluate the model on the given data using specified metrics. diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index e49a357..2a747b9 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -387,33 +387,21 @@ def predict(self, X, device=None): raise ValueError("The model or data module has not been fitted yet.") # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - if device is None: - device = next(self.task_model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) + self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) # Set model to evaluation mode self.task_model.eval() - # Perform inference - with torch.no_grad(): - predictions = self.task_model(num_features=num_tensors, cat_features=cat_tensors) - + # Perform inference using PyTorch Lightning's predict function + predictions_list = self.trainer.predict(self.task_model, self.data_module) + + # Concatenate predictions from all batches + predictions = torch.cat(predictions_list, dim=0) + # Check if ensemble is used if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble - # Average over the ensemble dimension (assuming shape: (batch_size, ensemble_size, output_dim)) - predictions = predictions.mean(dim=1) - + predictions = predictions.mean(dim=1) # Average over ensemble dimension + # Convert predictions to NumPy array and return return predictions.cpu().numpy() From 792c4a2e4ba25af73a48b128bcf32a06c2ce52e7 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 16:27:34 +0100 Subject: [PATCH 05/18] format --- mambular/data_utils/datamodule.py | 86 +++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py index d45452b..eb71737 100644 --- a/mambular/data_utils/datamodule.py +++ b/mambular/data_utils/datamodule.py @@ -123,7 +123,9 @@ def preprocess_data( self.y_val = y_val # Fit the preprocessor on the combined training and validation data - combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index(drop=True) + combined_X = pd.concat([self.X_train, self.X_val], axis=0).reset_index( + drop=True + ) combined_y = np.concatenate((self.y_train, self.y_val), axis=0) # Fit the preprocessor @@ -156,29 +158,53 @@ def setup(self, stage: str): else torch.long ) - cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' + cat_key = ( + "cat_" + key + ) # Assuming categorical keys are prefixed with 'cat_' if cat_key in train_preprocessed_data: - train_cat_tensors.append(torch.tensor(train_preprocessed_data[cat_key], dtype=dtype)) + train_cat_tensors.append( + torch.tensor(train_preprocessed_data[cat_key], dtype=dtype) + ) if cat_key in val_preprocessed_data: - val_cat_tensors.append(torch.tensor(val_preprocessed_data[cat_key], dtype=dtype)) + val_cat_tensors.append( + torch.tensor(val_preprocessed_data[cat_key], dtype=dtype) + ) binned_key = "num_" + key # for binned features if binned_key in train_preprocessed_data: - train_cat_tensors.append(torch.tensor(train_preprocessed_data[binned_key], dtype=dtype)) + train_cat_tensors.append( + torch.tensor(train_preprocessed_data[binned_key], dtype=dtype) + ) if binned_key in val_preprocessed_data: - val_cat_tensors.append(torch.tensor(val_preprocessed_data[binned_key], dtype=dtype)) + val_cat_tensors.append( + torch.tensor(val_preprocessed_data[binned_key], dtype=dtype) + ) # Populate tensors for numerical features, if present in processed data for key in self.num_feature_info: # type: ignore - num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_' + num_key = ( + "num_" + key + ) # Assuming numerical keys are prefixed with 'num_' if num_key in train_preprocessed_data: - train_num_tensors.append(torch.tensor(train_preprocessed_data[num_key], dtype=torch.float32)) + train_num_tensors.append( + torch.tensor( + train_preprocessed_data[num_key], dtype=torch.float32 + ) + ) if num_key in val_preprocessed_data: - val_num_tensors.append(torch.tensor(val_preprocessed_data[num_key], dtype=torch.float32)) - - train_labels = torch.tensor(self.y_train, dtype=self.labels_dtype).unsqueeze(dim=1) - val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze(dim=1) + val_num_tensors.append( + torch.tensor( + val_preprocessed_data[num_key], dtype=torch.float32 + ) + ) + + train_labels = torch.tensor( + self.y_train, dtype=self.labels_dtype + ).unsqueeze(dim=1) + val_labels = torch.tensor(self.y_val, dtype=self.labels_dtype).unsqueeze( + dim=1 + ) # Create datasets self.train_dataset = MambularDataset( @@ -187,7 +213,9 @@ def setup(self, stage: str): train_labels, regression=self.regression, ) - self.val_dataset = MambularDataset(val_cat_tensors, val_num_tensors, val_labels, regression=self.regression) + self.val_dataset = MambularDataset( + val_cat_tensors, val_num_tensors, val_labels, regression=self.regression + ) def preprocess_new_data(self, X): cat_tensors = [] @@ -204,20 +232,27 @@ def preprocess_new_data(self, X): ) cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' if cat_key in preprocessed_data: - cat_tensors.append(torch.tensor(preprocessed_data[cat_key], dtype=dtype)) + cat_tensors.append( + torch.tensor(preprocessed_data[cat_key], dtype=dtype) + ) binned_key = "num_" + key # for binned features if binned_key in preprocessed_data: - cat_tensors.append(torch.tensor(preprocessed_data[binned_key], dtype=dtype)) + cat_tensors.append( + torch.tensor(preprocessed_data[binned_key], dtype=dtype) + ) # Populate tensors for numerical features, if present in processed data for key in self.num_feature_info: # type: ignore num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_' if num_key in preprocessed_data: - num_tensors.append(torch.tensor(preprocessed_data[num_key], dtype=torch.float32)) + num_tensors.append( + torch.tensor(preprocessed_data[num_key], dtype=torch.float32) + ) - - return MambularDataset(cat_tensors, num_tensors, labels=None, regression=self.regression) + return MambularDataset( + cat_tensors, num_tensors, labels=None, regression=self.regression + ) def train_dataloader(self): """Returns the training dataloader. @@ -242,7 +277,9 @@ def val_dataloader(self): DataLoader: DataLoader instance for the validation dataset. """ if hasattr(self, "val_dataset"): - return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + return DataLoader( + self.val_dataset, batch_size=self.batch_size, **self.dataloader_kwargs + ) else: raise ValueError("No validation dataset provided!") @@ -253,13 +290,18 @@ def test_dataloader(self): DataLoader: DataLoader instance for the test dataset. """ if hasattr(self, "test_dataset"): - return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + return DataLoader( + self.test_dataset, batch_size=self.batch_size, **self.dataloader_kwargs + ) else: raise ValueError("No test dataset provided!") def predict_dataloader(self): if hasattr(self, "predict_dataset"): - return DataLoader(self.predict_dataset, batch_size=self.batch_size, **self.dataloader_kwargs) + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + **self.dataloader_kwargs, + ) else: raise ValueError("No predict dataset provided!") - From f37b6d3431a64808aa15d6df4e980fd8319e8299 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 16:27:45 +0100 Subject: [PATCH 06/18] adapt lightningmodule to have custom metrics --- mambular/base_models/lightning_wrapper.py | 98 +++++++++++------------ 1 file changed, 46 insertions(+), 52 deletions(-) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index 097a5b2..5621c4e 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -1,9 +1,7 @@ from collections.abc import Callable - import lightning as pl import torch import torch.nn as nn -import torchmetrics class TaskModel(pl.LightningModule): @@ -41,6 +39,8 @@ def __init__( pruning_epoch=5, optimizer_type: str = "Adam", optimizer_args: dict | None = None, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, **kwargs, ): super().__init__() @@ -53,6 +53,10 @@ def __init__( self.pruning_epoch = pruning_epoch self.val_losses = [] + # Store custom metrics + self.train_metrics = train_metrics or {} + self.val_metrics = val_metrics or {} + self.optimizer_params = { k.replace("optimizer_", ""): v for k, v in optimizer_args.items() # type: ignore @@ -65,16 +69,10 @@ def __init__( if num_classes == 2: if not self.loss_fct: self.loss_fct = nn.BCEWithLogitsLoss() - self.acc = torchmetrics.Accuracy(task="binary") - self.auroc = torchmetrics.AUROC(task="binary") - self.precision = torchmetrics.Precision(task="binary") self.num_classes = 1 elif num_classes > 2: if not self.loss_fct: self.loss_fct = nn.CrossEntropyLoss() - self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes) - self.auroc = torchmetrics.AUROC(task="multiclass", num_classes=num_classes) - self.precision = torchmetrics.Precision(task="multiclass", num_classes=num_classes) else: self.loss_fct = nn.MSELoss() @@ -146,7 +144,10 @@ def compute_loss(self, predictions, y_true): ) if getattr(self.base_model, "returns_ensemble", False): # Ensemble case - if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3: + if ( + self.loss_fct.__class__.__name__ == "CrossEntropyLoss" + and predictions.dim() == 3 + ): # Classification case with ensemble: predictions (N, E, k), y_true (N,) N, E, k = predictions.shape loss = 0.0 @@ -191,27 +192,30 @@ def training_step(self, batch, batch_idx): # type: ignore # Check if the model has a `penalty_forward` method if hasattr(self.base_model, "penalty_forward"): - preds, penalty = self.base_model.penalty_forward(num_features=num_features, cat_features=cat_features) + preds, penalty = self.base_model.penalty_forward( + num_features=num_features, cat_features=cat_features + ) loss = self.compute_loss(preds, labels) + penalty else: preds = self(num_features=num_features, cat_features=cat_features) loss = self.compute_loss(preds, labels) # Log the training loss - self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - - # Log additional metrics - if not self.lss and not hasattr(self.base_model, "returns_ensemble"): - if self.num_classes > 1: - acc = self.acc(preds, labels) - self.log( - "train_acc", - acc, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + + # Log custom training metrics + for metric_name, metric_fn in self.train_metrics.items(): + metric_value = metric_fn(preds, labels) + self.log( + f"train_{metric_name}", + metric_value, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) return loss @@ -244,18 +248,17 @@ def validation_step(self, batch, batch_idx): # type: ignore logger=True, ) - # Log additional metrics - if not self.lss and not hasattr(self.base_model, "returns_ensemble"): - if self.num_classes > 1: - acc = self.acc(preds, labels) - self.log( - "val_acc", - acc, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) + # Log custom validation metrics + for metric_name, metric_fn in self.val_metrics.items(): + metric_value = metric_fn(preds, labels) + self.log( + f"val_{metric_name}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) return val_loss @@ -287,19 +290,6 @@ def test_step(self, batch, batch_idx): # type: ignore logger=True, ) - # Log additional metrics - if not self.lss and not hasattr(self.base_model, "returns_ensemble"): - if self.num_classes > 1: - acc = self.acc(preds, labels) - self.log( - "test_acc", - acc, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return test_loss def predict_step(self, batch, batch_idx): @@ -323,7 +313,6 @@ def predict_step(self, batch, batch_idx): return preds - def on_validation_epoch_end(self): """Callback executed at the end of each validation epoch. @@ -363,8 +352,13 @@ def on_validation_epoch_end(self): # Apply pruning logic if needed if self.current_epoch >= self.pruning_epoch: - if self.early_pruning_threshold is not None and val_loss_value > self.early_pruning_threshold: - print(f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}") + if ( + self.early_pruning_threshold is not None + and val_loss_value > self.early_pruning_threshold + ): + print( + f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}" + ) self.trainer.should_stop = True # Stop training early def epoch_val_loss_at(self, epoch): From 44778a13787a035b799f6268109c41f41d74b582 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 16:34:47 +0100 Subject: [PATCH 07/18] assign datasets --- mambular/data_utils/datamodule.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py index eb71737..42c6fb4 100644 --- a/mambular/data_utils/datamodule.py +++ b/mambular/data_utils/datamodule.py @@ -254,6 +254,12 @@ def preprocess_new_data(self, X): cat_tensors, num_tensors, labels=None, regression=self.regression ) + def assign_predict_dataset(self, X): + self.predict_dataset = self.preprocess_new_data(X) + + def assign_test_dataset(self, X): + self.test_dataset = self.preprocess_new_data(X) + def train_dataloader(self): """Returns the training dataloader. From fd9a2573f8e34bf511a2e6061a8c29199ebe236d Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 16:35:00 +0100 Subject: [PATCH 08/18] include passing metrics into sklearn models --- mambular/models/sklearn_base_classifier.py | 150 ++++++++++++++------- mambular/models/sklearn_base_lss.py | 125 +++++++++++------ mambular/models/sklearn_base_regressor.py | 122 ++++++++++++----- 3 files changed, 271 insertions(+), 126 deletions(-) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index 8d13dd4..5b8a195 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -9,11 +9,15 @@ from sklearn.base import BaseEstimator from sklearn.metrics import accuracy_score, log_loss, mean_squared_error from skopt import gp_minimize - +from collections.abc import Callable from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor -from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16 +from ..utils.config_mapper import ( + activation_mapper, + get_search_space, + round_to_nearest_16, +) class SklearnBaseClassifier(BaseEstimator): @@ -36,11 +40,15 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names + } self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None @@ -60,7 +68,8 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -81,7 +90,10 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} + preprocessor_params = { + "prepro__" + key: value + for key, value in self.preprocessor.get_params().items() + } params.update(preprocessor_params) return params @@ -99,8 +111,14 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} - preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} + config_params = { + k: v for k, v in parameters.items() if not k.startswith("prepro__") + } + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("prepro__") + } if config_params: self.config_kwargs.update(config_params) @@ -108,9 +126,7 @@ def set_params(self, **parameters): for key, value in config_params.items(): setattr(self.config, key, value) else: - self.config = self.config_class( # type: ignore - **self.config_kwargs - ) + self.config = self.config_class(**self.config_kwargs) # type: ignore if preprocessor_params: self.preprocessor.set_params(**preprocessor_params) @@ -131,6 +147,8 @@ def build_model( lr_patience: int | None = None, lr_factor: float | None = None, weight_decay: float | None = None, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, dataloader_kwargs={}, ): """Builds the model using the provided training data. @@ -194,7 +212,9 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) num_classes = len(np.unique(np.array(y))) @@ -204,10 +224,16 @@ def build_model( config=self.config, cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, - lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), lr=lr if lr is not None else self.config.lr, lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), + train_metrics=train_metrics, + val_metrics=val_metrics, optimizer_type=self.optimizer_type, optimizer_args=self.optimizer_kwargs, ) @@ -236,7 +262,9 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError("The model must be built before the number of parameters can be estimated") + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -262,6 +290,8 @@ def fit( lr_factor: float | None = None, weight_decay: float | None = None, checkpoint_path="model_checkpoints", + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, dataloader_kwargs={}, rebuild=True, **trainer_kwargs, @@ -332,6 +362,8 @@ def fit( lr_patience=lr_patience, lr_factor=lr_factor, weight_decay=weight_decay, + train_metrics=train_metrics, + val_metrics=val_metrics, dataloader_kwargs=dataloader_kwargs, ) @@ -369,20 +401,18 @@ def fit( best_model_path = checkpoint_callback.best_model_path if best_model_path: checkpoint = torch.load(best_model_path) - self.task_model.load_state_dict( # type: ignore - checkpoint["state_dict"] - ) + self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore return self def predict(self, X, device=None): """Predicts target labels for the given input samples. - + Parameters ---------- X : DataFrame or array-like, shape (n_samples, n_features) The input samples for which to predict target values. - + Returns ------- predictions : ndarray, shape (n_samples,) @@ -391,25 +421,25 @@ def predict(self, X, device=None): # Ensure model and data module are initialized if self.task_model is None or self.data_module is None: raise ValueError("The model or data module has not been fitted yet.") - + # Preprocess the data using the data module - self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) - + self.data_module.assign_predict_dataset(X) + # Set model to evaluation mode self.task_model.eval() - + # Perform inference using PyTorch Lightning's predict function logits_list = self.trainer.predict(self.task_model, self.data_module) - + # Concatenate predictions from all batches logits = torch.cat(logits_list, dim=0) - + # Check if ensemble is used if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble logits = logits.mean(dim=1) # Average over ensemble dimension if logits.dim() == 1: # Ensure correct shape logits = logits.unsqueeze(1) - + # Check the shape of the logits to determine binary or multi-class classification if logits.shape[1] == 1: # Binary classification @@ -419,19 +449,18 @@ def predict(self, X, device=None): # Multi-class classification probabilities = torch.softmax(logits, dim=1) predictions = torch.argmax(probabilities, dim=1) - + # Convert predictions to NumPy array and return return predictions.cpu().numpy() - - + def predict_proba(self, X, device=None): """Predicts class probabilities for the given input samples. - + Parameters ---------- X : DataFrame or array-like, shape (n_samples, n_features) The input samples for which to predict class probabilities. - + Returns ------- probabilities : ndarray, shape (n_samples, n_classes) @@ -440,35 +469,34 @@ def predict_proba(self, X, device=None): # Ensure model and data module are initialized if self.task_model is None or self.data_module is None: raise ValueError("The model or data module has not been fitted yet.") - + # Preprocess the data using the data module - self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) - + self.data_module.assign_predict_dataset(X) + # Set model to evaluation mode self.task_model.eval() - + # Perform inference using PyTorch Lightning's predict function logits_list = self.trainer.predict(self.task_model, self.data_module) - + # Concatenate predictions from all batches logits = torch.cat(logits_list, dim=0) - + # Check if ensemble is used if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble logits = logits.mean(dim=1) # Average over ensemble dimension if logits.dim() == 1: # Ensure correct shape logits = logits.unsqueeze(1) - + # Compute probabilities if logits.shape[1] > 1: probabilities = torch.softmax(logits, dim=1) # Multi-class classification else: probabilities = torch.sigmoid(logits) # Binary classification - + # Convert probabilities to NumPy array and return return probabilities.cpu().numpy() - def evaluate(self, X, y_true, metrics=None): """Evaluate the model on the given data using specified metrics. @@ -610,9 +638,13 @@ def optimize_hparams( best_val_loss = float("inf") if X_val is not None and y_val is not None: - val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"] + val_loss = self.evaluate( + X_val, y_val, metrics={"Accuracy": (accuracy_score, False)} + )["Accuracy"] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -638,7 +670,9 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError(f"Unknown activation function: {param_value}") + raise ValueError( + f"Unknown activation function: {param_value}" + ) else: setattr(self.config, key, param_value) @@ -647,11 +681,15 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -663,7 +701,9 @@ def _objective(hyperparams): # Fit the model (limit epochs for faster optimization) try: # Wrap the risky operation (model fitting) in a try-except block - self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False) + self.fit( + X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False + ) # Evaluate validation loss if X_val is not None and y_val is not None: @@ -671,7 +711,9 @@ def _objective(hyperparams): "Mean Squared Error" ] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -688,15 +730,21 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") - return best_val_loss * 100 # Large value to discourage this configuration + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 1832086..1d524bb 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -9,11 +9,15 @@ from sklearn.base import BaseEstimator from sklearn.metrics import accuracy_score, mean_squared_error from skopt import gp_minimize - +from collections.abc import Callable from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor -from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16 +from ..utils.config_mapper import ( + activation_mapper, + get_search_space, + round_to_nearest_16, +) from ..utils.distributional_metrics import ( beta_brier_score, dirichlet_error, @@ -57,11 +61,15 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names + } self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None @@ -82,7 +90,8 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -103,7 +112,10 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} + preprocessor_params = { + "prepro__" + key: value + for key, value in self.preprocessor.get_params().items() + } params.update(preprocessor_params) return params @@ -121,8 +133,14 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} - preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} + config_params = { + k: v for k, v in parameters.items() if not k.startswith("prepro__") + } + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("prepro__") + } if config_params: self.config_kwargs.update(config_params) @@ -130,9 +148,7 @@ def set_params(self, **parameters): for key, value in config_params.items(): setattr(self.config, key, value) else: - self.config = self.config_class( # type: ignore - **self.config_kwargs - ) + self.config = self.config_class(**self.config_kwargs) # type: ignore if preprocessor_params: self.preprocessor.set_params(**preprocessor_params) @@ -153,6 +169,8 @@ def build_model( lr_patience: int | None = None, lr_factor: float | None = None, weight_decay: float | None = None, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, dataloader_kwargs={}, ): """Builds the model using the provided training data. @@ -214,7 +232,9 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) self.task_model = TaskModel( model_class=self.base_model, # type: ignore @@ -224,10 +244,16 @@ def build_model( cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, lr=lr if lr is not None else self.config.lr, - lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), lss=True, + train_metrics=train_metrics, + val_metrics=val_metrics, optimizer_type=self.optimizer_type, optimizer_args=self.optimizer_kwargs, ) @@ -256,7 +282,9 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError("The model must be built before the number of parameters can be estimated") + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -284,6 +312,8 @@ def fit( weight_decay: float | None = None, checkpoint_path="model_checkpoints", distributional_kwargs=None, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, dataloader_kwargs={}, rebuild=True, **trainer_kwargs, @@ -377,6 +407,8 @@ def fit( lr=lr, lr_patience=lr_patience, lr_factor=lr_factor, + train_metrics=train_metrics, + val_metrics=val_metrics, weight_decay=weight_decay, dataloader_kwargs=dataloader_kwargs, ) @@ -415,9 +447,7 @@ def fit( best_model_path = checkpoint_callback.best_model_path if best_model_path: checkpoint = torch.load(best_model_path) - self.task_model.load_state_dict( # type: ignore - checkpoint["state_dict"] - ) + self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore return self @@ -440,21 +470,21 @@ def predict(self, X, device=None): raise ValueError("The model or data module has not been fitted yet.") # Preprocess the data using the data module - self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) + self.data_module.assign_predict_dataset(X) # Set model to evaluation mode self.task_model.eval() # Perform inference using PyTorch Lightning's predict function predictions_list = self.trainer.predict(self.task_model, self.data_module) - + # Concatenate predictions from all batches predictions = torch.cat(predictions_list, dim=0) - + # Check if ensemble is used if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble predictions = predictions.mean(dim=1) # Average over ensemble dimension - + # Convert predictions to NumPy array and return return predictions.cpu().numpy() @@ -487,7 +517,9 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None): """ # Infer distribution family from model settings if not provided if distribution_family is None: - distribution_family = getattr(self.task_model, "distribution_family", "normal") + distribution_family = getattr( + self.task_model, "distribution_family", "normal" + ) # Setup default metrics if none are provided if metrics is None: @@ -523,7 +555,10 @@ def get_default_metrics(self, distribution_family): "normal": { "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]), "CRPS": lambda y, pred: np.mean( - [ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) for i in range(len(y))] + [ + ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) + for i in range(len(y)) + ] ), }, "poisson": {"Poisson Deviance": poisson_deviance}, @@ -555,9 +590,7 @@ def score(self, X, y, metric="NLL"): The score calculated using the specified metric. """ predictions = self.predict(X) - score = self.task_model.family.evaluate_nll( # type: ignore - y, predictions - ) + score = self.task_model.family.evaluate_nll(y, predictions) # type: ignore return score def optimize_hparams( @@ -625,7 +658,9 @@ def optimize_hparams( y_val, ) else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -651,7 +686,9 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError(f"Unknown activation function: {param_value}") + raise ValueError( + f"Unknown activation function: {param_value}" + ) else: setattr(self.config, key, param_value) @@ -660,11 +697,15 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -686,11 +727,13 @@ def _objective(hyperparams): # Evaluate validation loss if X_val is not None and y_val is not None: - val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ - "Mean Squared Error" - ] + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -707,15 +750,21 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") - return best_val_loss * 100 # Large value to discourage this configuration + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index 2a747b9..2482ba9 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -7,11 +7,15 @@ from sklearn.base import BaseEstimator from sklearn.metrics import mean_squared_error from skopt import gp_minimize - +from collections.abc import Callable from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor -from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16 +from ..utils.config_mapper import ( + activation_mapper, + get_search_space, + round_to_nearest_16, +) class SklearnBaseRegressor(BaseEstimator): @@ -34,11 +38,15 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v + for k, v in kwargs.items() + if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names + } self.preprocessor = Preprocessor(**preprocessor_kwargs) self.base_model = model @@ -58,7 +66,8 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k + not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -79,7 +88,10 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} + preprocessor_params = { + "prepro__" + key: value + for key, value in self.preprocessor.get_params().items() + } params.update(preprocessor_params) return params @@ -97,8 +109,14 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} - preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} + config_params = { + k: v for k, v in parameters.items() if not k.startswith("prepro__") + } + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("prepro__") + } if config_params: self.config_kwargs.update(config_params) @@ -106,9 +124,7 @@ def set_params(self, **parameters): for key, value in config_params.items(): setattr(self.config, key, value) else: - self.config = self.config_class( # type: ignore - **self.config_kwargs - ) + self.config = self.config_class(**self.config_kwargs) # type: ignore if preprocessor_params: self.preprocessor.set_params(**preprocessor_params) @@ -129,6 +145,8 @@ def build_model( lr_patience: int | None = None, lr_factor: float | None = None, weight_decay: float | None = None, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, dataloader_kwargs={}, ): """Builds the model using the provided training data. @@ -192,7 +210,9 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) self.task_model = TaskModel( model_class=self.base_model, # type: ignore @@ -200,9 +220,15 @@ def build_model( cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, lr=lr if lr is not None else self.config.lr, - lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), + lr_patience=( + lr_patience if lr_patience is not None else self.config.lr_patience + ), lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), + weight_decay=( + weight_decay if weight_decay is not None else self.config.weight_decay + ), + train_metrics=train_metrics, + val_metrics=val_metrics, optimizer_type=self.optimizer_type, optimizer_args=self.optimizer_kwargs, ) @@ -231,7 +257,9 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError("The model must be built before the number of parameters can be estimated") + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -258,6 +286,8 @@ def fit( weight_decay: float | None = None, checkpoint_path="model_checkpoints", dataloader_kwargs={}, + train_metrics: dict[str, Callable] | None = None, + val_metrics: dict[str, Callable] | None = None, rebuild=True, **trainer_kwargs, ): @@ -326,6 +356,8 @@ def fit( lr_factor=lr_factor, weight_decay=weight_decay, dataloader_kwargs=dataloader_kwargs, + train_metrics=train_metrics, + val_metrics=val_metrics, ) else: @@ -362,9 +394,7 @@ def fit( best_model_path = checkpoint_callback.best_model_path if best_model_path: checkpoint = torch.load(best_model_path) - self.task_model.load_state_dict( # type: ignore - checkpoint["state_dict"] - ) + self.task_model.load_state_dict(checkpoint["state_dict"]) # type: ignore return self @@ -387,21 +417,21 @@ def predict(self, X, device=None): raise ValueError("The model or data module has not been fitted yet.") # Preprocess the data using the data module - self.data_module.predict_dataset = self.data_module.preprocess_new_data(X) + self.data_module.assign_predict_dataset(X) # Set model to evaluation mode self.task_model.eval() # Perform inference using PyTorch Lightning's predict function predictions_list = self.trainer.predict(self.task_model, self.data_module) - + # Concatenate predictions from all batches predictions = torch.cat(predictions_list, dim=0) - + # Check if ensemble is used if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble predictions = predictions.mean(dim=1) # Average over ensemble dimension - + # Convert predictions to NumPy array and return return predictions.cpu().numpy() @@ -522,11 +552,13 @@ def optimize_hparams( best_val_loss = float("inf") if X_val is not None and y_val is not None: - val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ - "Mean Squared Error" - ] + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ + "val_loss" + ] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -552,7 +584,9 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError(f"Unknown activation function: {param_value}") + raise ValueError( + f"Unknown activation function: {param_value}" + ) else: setattr(self.config, key, param_value) @@ -561,11 +595,15 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) + self.build_model( + X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs + ) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss + early_pruning_threshold = ( + best_epoch_val_loss * 1.5 + ) # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -576,15 +614,19 @@ def _objective(hyperparams): try: # Wrap the risky operation (model fitting) in a try-except block - self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False) + self.fit( + X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False + ) # Evaluate validation loss if X_val is not None and y_val is not None: - val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ - "Mean Squared Error" - ] + val_loss = self.evaluate( + X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} + )["Mean Squared Error"] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[ + 0 + ]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -601,15 +643,21 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") - return best_val_loss * 100 # Large value to discourage this configuration + print( + f"Error encountered during fit with hyperparameters {hyperparams}: {e}" + ) + return ( + best_val_loss * 100 + ) # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + head_layer_sizes = ( + [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None + ) layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization From c5d2931dac4ea9bd044d8064b68b20ad37f05252 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 16:52:08 +0100 Subject: [PATCH 09/18] fix ensemble prediction bug --- mambular/models/sklearn_base_classifier.py | 4 ++-- mambular/models/sklearn_base_lss.py | 11 +++++++---- mambular/models/sklearn_base_regressor.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index 5b8a195..c1c264b 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -435,7 +435,7 @@ def predict(self, X, device=None): logits = torch.cat(logits_list, dim=0) # Check if ensemble is used - if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + if getattr(self.base_model, "returns_ensemble", False): # If using ensemble logits = logits.mean(dim=1) # Average over ensemble dimension if logits.dim() == 1: # Ensure correct shape logits = logits.unsqueeze(1) @@ -483,7 +483,7 @@ def predict_proba(self, X, device=None): logits = torch.cat(logits_list, dim=0) # Check if ensemble is used - if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + if getattr(self.base_model, "returns_ensemble", False): # If using ensemble logits = logits.mean(dim=1) # Average over ensemble dimension if logits.dim() == 1: # Ensure correct shape logits = logits.unsqueeze(1) diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 1d524bb..f01cf6f 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -451,7 +451,7 @@ def fit( return self - def predict(self, X, device=None): + def predict(self, X, raw=False, device=None): """Predicts target values for the given input samples. Parameters @@ -482,11 +482,14 @@ def predict(self, X, device=None): predictions = torch.cat(predictions_list, dim=0) # Check if ensemble is used - if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + if getattr(self.base_model, "returns_ensemble", False): # If using ensemble predictions = predictions.mean(dim=1) # Average over ensemble dimension - # Convert predictions to NumPy array and return - return predictions.cpu().numpy() + if not raw: + result = self.task_model.family(predictions).cpu().numpy() # type: ignore + return result + else: + return predictions.cpu().numpy() def evaluate(self, X, y_true, metrics=None, distribution_family=None): """Evaluate the model on the given data using specified metrics. diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index 2482ba9..4964b5c 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -429,7 +429,7 @@ def predict(self, X, device=None): predictions = torch.cat(predictions_list, dim=0) # Check if ensemble is used - if hasattr(self.task_model.base_model, "returns_ensemble"): # If using ensemble + if getattr(self.base_model, "returns_ensemble", False): # If using ensemble predictions = predictions.mean(dim=1) # Average over ensemble dimension # Convert predictions to NumPy array and return From 850a5ccbb792f66212d6f3f1814f2f2ee3538b0a Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 18:46:42 +0100 Subject: [PATCH 10/18] include sentence/word embeddings as preprocessing techniques for categorical --- .../arch_utils/layer_utils/embedding_layer.py | 50 +++++--- mambular/data_utils/datamodule.py | 12 +- mambular/preprocessing/prepro_utils.py | 67 +++++++++- mambular/preprocessing/preprocessor.py | 115 ++++++++++++++---- 4 files changed, 190 insertions(+), 54 deletions(-) diff --git a/mambular/arch_utils/layer_utils/embedding_layer.py b/mambular/arch_utils/layer_utils/embedding_layer.py index 76afd1e..0fb93fd 100644 --- a/mambular/arch_utils/layer_utils/embedding_layer.py +++ b/mambular/arch_utils/layer_utils/embedding_layer.py @@ -22,8 +22,12 @@ def __init__(self, num_feature_info, cat_feature_info, config): super().__init__() self.d_model = getattr(config, "d_model", 128) - self.embedding_activation = getattr(config, "embedding_activation", nn.Identity()) - self.layer_norm_after_embedding = getattr(config, "layer_norm_after_embedding", False) + self.embedding_activation = getattr( + config, "embedding_activation", nn.Identity() + ) + self.layer_norm_after_embedding = getattr( + config, "layer_norm_after_embedding", False + ) self.use_cls = getattr(config, "use_cls", False) self.cls_position = getattr(config, "cls_position", 0) self.embedding_dropout = ( @@ -71,22 +75,26 @@ def __init__(self, num_feature_info, cat_feature_info, config): # for splines and other embeddings # splines followed by linear if n_knots actual knots is less than the defined knots else: - raise ValueError("Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'.") + raise ValueError( + "Invalid embedding_type. Choose from 'linear', 'ndt', or 'plr'." + ) self.cat_embeddings = nn.ModuleList( [ - nn.Sequential( - nn.Embedding(feature_info["categories"] + 1, self.d_model), - self.embedding_activation, - ) - if feature_info["dimension"] == 1 - else nn.Sequential( - nn.Linear( - feature_info["dimension"], - self.d_model, - bias=self.embedding_bias, - ), - self.embedding_activation, + ( + nn.Sequential( + nn.Embedding(feature_info["categories"] + 1, self.d_model), + self.embedding_activation, + ) + if feature_info["dimension"] == 1 + else nn.Sequential( + nn.Linear( + feature_info["dimension"], + self.d_model, + bias=self.embedding_bias, + ), + self.embedding_activation, + ) ) for feature_name, feature_info in cat_feature_info.items() ] @@ -124,9 +132,7 @@ def forward(self, num_features=None, cat_features=None): # Class token initialization if self.use_cls: batch_size = ( - cat_features[0].size( # type: ignore - 0 - ) + cat_features[0].size(0) # type: ignore if cat_features != [] else num_features[0].size(0) # type: ignore ) # type: ignore @@ -134,7 +140,9 @@ def forward(self, num_features=None, cat_features=None): # Process categorical embeddings if self.cat_embeddings and cat_features is not None: - cat_embeddings = [emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)] + cat_embeddings = [ + emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings) + ] cat_embeddings = torch.stack(cat_embeddings, dim=1) cat_embeddings = torch.squeeze(cat_embeddings, dim=2) if self.layer_norm_after_embedding: @@ -182,7 +190,9 @@ def forward(self, num_features=None, cat_features=None): elif self.cls_position == 1: x = torch.cat([x, cls_tokens], dim=1) # type: ignore else: - raise ValueError("Invalid cls_position value. It should be either 0 or 1.") + raise ValueError( + "Invalid cls_position value. It should be either 0 or 1." + ) # Apply dropout to embeddings if specified in config if self.embedding_dropout is not None: diff --git a/mambular/data_utils/datamodule.py b/mambular/data_utils/datamodule.py index 42c6fb4..b6bfb32 100644 --- a/mambular/data_utils/datamodule.py +++ b/mambular/data_utils/datamodule.py @@ -153,8 +153,10 @@ def setup(self, stage: str): for key in self.cat_feature_info: # type: ignore dtype = ( torch.float32 - if "onehot" - in self.cat_feature_info[key]["preprocessing"] # type: ignore + if any( + x in self.cat_feature_info[key]["preprocessing"] + for x in ["onehot", "pretrained"] + ) else torch.long ) @@ -226,8 +228,10 @@ def preprocess_new_data(self, X): for key in self.cat_feature_info: # type: ignore dtype = ( torch.float32 - if "onehot" - in self.cat_feature_info[key]["preprocessing"] # type: ignore + if any( + x in self.cat_feature_info[key]["preprocessing"] + for x in ["onehot", "pretrained"] + ) else torch.long ) cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' diff --git a/mambular/preprocessing/prepro_utils.py b/mambular/preprocessing/prepro_utils.py index 704091f..806d3f5 100644 --- a/mambular/preprocessing/prepro_utils.py +++ b/mambular/preprocessing/prepro_utils.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin +from sentence_transformers import SentenceTransformer class CustomBinner(TransformerMixin): @@ -10,6 +11,7 @@ def __init__(self, bins): def fit(self, X, y=None): # Fit doesn't need to do anything as we are directly using provided bins + self.n_features_in_ = 1 return self def transform(self, X): @@ -56,7 +58,10 @@ def fit(self, X, y=None): self: Returns the instance itself. """ # Fit should determine the mapping from original categories to sequential integers starting from 0 - self.mapping_ = [{category: i + 1 for i, category in enumerate(np.unique(col))} for col in X.T] + self.mapping_ = [ + {category: i + 1 for i, category in enumerate(np.unique(col))} + for col in X.T + ] for mapping in self.mapping_: mapping[None] = 0 # Assign 0 to unknown values return self @@ -71,7 +76,12 @@ def transform(self, X): X_transformed (ndarray of shape (n_samples, n_features)): The transformed data with integer values. """ # Transform the categories to their mapped integer values - X_transformed = np.array([[self.mapping_[col].get(value, 0) for col, value in enumerate(row)] for row in X]) + X_transformed = np.array( + [ + [self.mapping_[col].get(value, 0) for col, value in enumerate(row)] + for row in X + ] + ) return X_transformed def get_feature_names_out(self, input_features=None): @@ -113,7 +123,9 @@ def fit(self, X, y=None): Returns: self: Returns the instance itself. """ - self.max_bins_ = np.max(X, axis=0).astype(int) + 1 # Find the maximum bin index for each feature + self.max_bins_ = ( + np.max(X, axis=0).astype(int) + 1 + ) # Find the maximum bin index for each feature return self def transform(self, X): @@ -172,6 +184,7 @@ def fit(self, X, y=None): Returns: self: Returns the instance itself. """ + self.n_features_in_ = 1 return self def transform(self, X): @@ -195,7 +208,9 @@ def get_feature_names_out(self, input_features=None): feature_names (array of shape (n_features,)): The original feature names. """ if input_features is None: - raise ValueError("input_features must be provided to generate feature names.") + raise ValueError( + "input_features must be provided to generate feature names." + ) return np.array(input_features) @@ -203,7 +218,51 @@ class ToFloatTransformer(TransformerMixin, BaseEstimator): """A transformer that converts input data to float type.""" def fit(self, X, y=None): + self.n_features_in_ = 1 return self def transform(self, X): return X.astype(float) + + +class LanguageEmbeddingTransformer(TransformerMixin, BaseEstimator): + """A transformer that encodes categorical text features into embeddings using a pre-trained language model.""" + + def __init__(self, model_name="paraphrase-MiniLM-L3-v2"): + """ + Initializes the transformer with a language embedding model. + + Parameters: + - model_name (str): The name of the SentenceTransformer model to use. + """ + self.model_name = model_name + self.model = SentenceTransformer(model_name) + + def fit(self, X, y=None): + """ + Fit method (not required for a transformer but included for compatibility). + """ + self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1 + return self + + def transform(self, X): + """ + Transforms input categorical text features into numerical embeddings. + + Parameters: + - X: A 1D or 2D array-like of categorical text features. + + Returns: + - A 2D numpy array with embeddings for each text input. + """ + if isinstance(X, np.ndarray): + X = ( + X.flatten().astype(str).tolist() + ) # Convert to a list of strings if passed as an array + elif isinstance(X, list): + X = [str(x) for x in X] # Ensure everything is a string + + embeddings = self.model.encode( + X, convert_to_numpy=True + ) # Get sentence embeddings + return embeddings diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index 6d079db..be20529 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -19,7 +19,14 @@ from .basis_expansion import RBFExpansion, SigmoidExpansion, SplineExpansion from .ple_encoding import PLE -from .prepro_utils import ContinuousOrdinalEncoder, CustomBinner, NoTransformer, OneHotFromOrdinal, ToFloatTransformer +from .prepro_utils import ( + ContinuousOrdinalEncoder, + CustomBinner, + NoTransformer, + OneHotFromOrdinal, + ToFloatTransformer, + LanguageEmbeddingTransformer, +) class Preprocessor: @@ -104,10 +111,14 @@ def __init__( ): self.n_bins = n_bins self.numerical_preprocessing = ( - numerical_preprocessing.lower() if numerical_preprocessing is not None else "none" + numerical_preprocessing.lower() + if numerical_preprocessing is not None + else "none" ) self.categorical_preprocessing = ( - categorical_preprocessing.lower() if categorical_preprocessing is not None else "none" + categorical_preprocessing.lower() + if categorical_preprocessing is not None + else "none" ) if self.numerical_preprocessing not in [ "ple", @@ -131,8 +142,15 @@ def __init__( 'rbf', 'sigmoid', or 'None'." ) - if self.categorical_preprocessing not in ["int", "one-hot", "none"]: - raise ValueError("invalid categorical_preprocessing value. Supported values are 'int' and 'one-hot'") + if self.categorical_preprocessing not in [ + "int", + "one-hot", + "pretrained", + "none", + ]: + raise ValueError( + "invalid categorical_preprocessing value. Supported values are 'int', 'pretrained', 'none' and 'one-hot'" + ) self.use_decision_tree_bins = use_decision_tree_bins self.column_transformer = None @@ -223,13 +241,19 @@ def _detect_column_types(self, X): numerical_features.append(col) else: if isinstance(self.cat_cutoff, float): - cutoff_condition = (num_unique_values / total_samples) < self.cat_cutoff + cutoff_condition = ( + num_unique_values / total_samples + ) < self.cat_cutoff elif isinstance(self.cat_cutoff, int): cutoff_condition = num_unique_values < self.cat_cutoff else: - raise ValueError("cat_cutoff should be either a float or an integer.") + raise ValueError( + "cat_cutoff should be either a float or an integer." + ) - if X[col].dtype.kind not in "iufc" or (X[col].dtype.kind == "i" and cutoff_condition): + if X[col].dtype.kind not in "iufc" or ( + X[col].dtype.kind == "i" and cutoff_condition + ): categorical_features.append(col) else: numerical_features.append(col) @@ -256,8 +280,6 @@ def fit(self, X, y=None): X = pd.DataFrame(X) numerical_features, categorical_features = self._detect_column_types(X) - print("Numerical features:", numerical_features) - print("Categorical features:", categorical_features) transformers = [] if numerical_features: @@ -296,7 +318,11 @@ def fit(self, X, y=None): ( "discretizer", KBinsDiscretizer( - n_bins=(bins if isinstance(bins, int) else len(bins) - 1), + n_bins=( + bins + if isinstance(bins, int) + else len(bins) - 1 + ), encode="ordinal", strategy=self.binning_strategy, # type: ignore subsample=200_000 if len(X) > 200_000 else None, @@ -325,13 +351,17 @@ def fit(self, X, y=None): numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.numerical_preprocessing == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) elif self.numerical_preprocessing == "quantile": numeric_transformer_steps.append( ( "quantile", - QuantileTransformer(n_quantiles=self.n_bins, random_state=101), + QuantileTransformer( + n_quantiles=self.n_bins, random_state=101 + ), ) ) @@ -339,7 +369,9 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "polynomial", @@ -354,7 +386,9 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "splines", @@ -373,7 +407,9 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "rbf", @@ -390,7 +426,9 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) numeric_transformer_steps.append( ( "sigmoid", @@ -404,8 +442,12 @@ def fit(self, X, y=None): ) elif self.numerical_preprocessing == "ple": - numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) - numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task))) + numeric_transformer_steps.append( + ("minmax", MinMaxScaler(feature_range=(-1, 1))) + ) + numeric_transformer_steps.append( + ("ple", PLE(n_bins=self.n_bins, task=self.task)) + ) elif self.numerical_preprocessing == "box-cox": numeric_transformer_steps.append( @@ -463,13 +505,26 @@ def fit(self, X, y=None): ("none", NoTransformer()), ] ) + elif self.categorical_preprocessing == "pretrained": + categorical_transformer = Pipeline( + [ + ("imputer", SimpleImputer(strategy="most_frequent")), + ("pretrained", LanguageEmbeddingTransformer()), + ] + ) else: - raise ValueError(f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}") + raise ValueError( + f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}" + ) # Append the transformer for the current categorical feature - transformers.append((f"cat_{feature}", categorical_transformer, [feature])) + transformers.append( + (f"cat_{feature}", categorical_transformer, [feature]) + ) - self.column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough") + self.column_transformer = ColumnTransformer( + transformers=transformers, remainder="passthrough" + ) self.column_transformer.fit(X, y) self.fitted = True @@ -495,13 +550,17 @@ def _get_decision_tree_bins(self, X, y, numerical_features): bins = [] for feature in numerical_features: tree_model = ( - DecisionTreeClassifier(max_depth=3) if y.dtype.kind in "bi" else DecisionTreeRegressor(max_depth=3) + DecisionTreeClassifier(max_depth=3) + if y.dtype.kind in "bi" + else DecisionTreeRegressor(max_depth=3) ) tree_model.fit(X[[feature]], y) thresholds = tree_model.tree_.threshold[tree_model.tree_.feature != -2] # type: ignore bin_edges = np.sort(np.unique(thresholds)) - bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()]))) + bins.append( + np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])) + ) return bins def transform(self, X): @@ -657,7 +716,9 @@ def get_feature_info(self, verbose=True): "categories": None, # Numerical features don't have categories } if verbose: - print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}") + print( + f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}" + ) # Categorical features elif "continuous_ordinal" in steps: @@ -709,7 +770,9 @@ def get_feature_info(self, verbose=True): "categories": None, # Numerical features don't have categories } if verbose: - print(f"Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}") + print( + f"Categorical Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}" + ) if verbose: print("-" * 50) From 50a3883ca80f621556d09dcfd9a4e6ce0c7d8be9 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Fri, 17 Jan 2025 18:53:48 +0100 Subject: [PATCH 11/18] make sentence_transformer input optional dependency --- mambular/preprocessing/prepro_utils.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/mambular/preprocessing/prepro_utils.py b/mambular/preprocessing/prepro_utils.py index 806d3f5..1d5e41b 100644 --- a/mambular/preprocessing/prepro_utils.py +++ b/mambular/preprocessing/prepro_utils.py @@ -1,7 +1,6 @@ import numpy as np import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin -from sentence_transformers import SentenceTransformer class CustomBinner(TransformerMixin): @@ -228,20 +227,29 @@ def transform(self, X): class LanguageEmbeddingTransformer(TransformerMixin, BaseEstimator): """A transformer that encodes categorical text features into embeddings using a pre-trained language model.""" - def __init__(self, model_name="paraphrase-MiniLM-L3-v2"): + def __init__(self, model_name="paraphrase-MiniLM-L3-v2", model=None): """ Initializes the transformer with a language embedding model. Parameters: - - model_name (str): The name of the SentenceTransformer model to use. + - model_name (str): The name of the SentenceTransformer model to use (if model is None). + - model (object, optional): A preloaded SentenceTransformer model instance. """ self.model_name = model_name - self.model = SentenceTransformer(model_name) + self.model = model # Allow user to pass a preloaded model + + if self.model is None: + try: + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer(model_name) + except ImportError: + raise ImportError( + "sentence-transformers is not installed. Install it via `pip install sentence-transformers` or provide a preloaded model." + ) def fit(self, X, y=None): - """ - Fit method (not required for a transformer but included for compatibility). - """ + """Fit method (not required for a transformer but included for compatibility).""" self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1 return self From fac6a1fe0367a1f4d8ef171381867da95235d5d8 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Sat, 18 Jan 2025 12:41:34 +0100 Subject: [PATCH 12/18] include encoding function to create embeddings --- mambular/base_models/basemodel.py | 36 +++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/mambular/base_models/basemodel.py b/mambular/base_models/basemodel.py index a3b1821..fd21852 100644 --- a/mambular/base_models/basemodel.py +++ b/mambular/base_models/basemodel.py @@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]): List of keys to ignore while saving hyperparameters, by default []. """ # Filter the config and extra hparams for ignored keys - config_hparams = {k: v for k, v in vars(self.config).items() if k not in ignore} if self.config else {} + config_hparams = ( + {k: v for k, v in vars(self.config).items() if k not in ignore} + if self.config + else {} + ) extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore} config_hparams.update(extra_hparams) @@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs): """Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method.""" if self.hparams.pooling_method == "learned_flatten": # Flattening + Linear layer - self.learned_flatten_pooling = nn.Linear(n_inputs * config.dim_feedforward, config.dim_feedforward) + self.learned_flatten_pooling = nn.Linear( + n_inputs * config.dim_feedforward, config.dim_feedforward + ) elif self.hparams.pooling_method == "attention": # Attention-based pooling with learnable attention weights @@ -216,3 +222,29 @@ def pool_sequence(self, out): return out else: raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}") + + def encode(self, num_features, cat_features): + if not hasattr(self, "embedding_layer"): + raise ValueError("The model does not have an embedding layer") + + # Check if at least one of the contextualized embedding methods exists + valid_layers = ["mamba", "rnn", "lstm", "encoder"] + available_layer = next( + (attr for attr in valid_layers if hasattr(self, attr)), None + ) + + if not available_layer: + raise ValueError("The model does not generate contextualized embeddings") + + # Get the actual layer and call it + x = self.embedding_layer(num_features=num_features, cat_features=cat_features) + + if getattr(self.hparams, "shuffle_embeddings", False): + x = x[:, self.perm, :] + + layer = getattr(self, available_layer) + if available_layer == "rnn": + embeddings, _ = layer(x) + else: + embeddings = layer(x) + return embeddings From d08af31ee86da631ebcc3753ac9878f8e96240fb Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Sat, 18 Jan 2025 12:41:56 +0100 Subject: [PATCH 13/18] adjust order in __getitem__ functionality and batch for lightningmodule --- mambular/base_models/lightning_wrapper.py | 8 ++++---- mambular/data_utils/dataset.py | 13 ++++++++----- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index 5621c4e..afc17b8 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -188,7 +188,7 @@ def training_step(self, batch, batch_idx): # type: ignore Tensor Training loss. """ - cat_features, num_features, labels = batch + num_features, cat_features, labels = batch # Check if the model has a `penalty_forward` method if hasattr(self.base_model, "penalty_forward"): @@ -235,7 +235,7 @@ def validation_step(self, batch, batch_idx): # type: ignore Validation loss. """ - cat_features, num_features, labels = batch + num_features, cat_features, labels = batch preds = self(num_features=num_features, cat_features=cat_features) val_loss = self.compute_loss(preds, labels) @@ -277,7 +277,7 @@ def test_step(self, batch, batch_idx): # type: ignore Tensor Test loss. """ - cat_features, num_features, labels = batch + num_features, cat_features, labels = batch preds = self(num_features=num_features, cat_features=cat_features) test_loss = self.compute_loss(preds, labels) @@ -308,7 +308,7 @@ def predict_step(self, batch, batch_idx): Predictions. """ - cat_features, num_features = batch + num_features, cat_features = batch preds = self(num_features=num_features, cat_features=cat_features) return preds diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index 034a581..20076ea 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -20,7 +20,9 @@ class MambularDataset(Dataset): regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True. """ - def __init__(self, cat_features_list, num_features_list, labels=None, regression=True): + def __init__( + self, cat_features_list, num_features_list, labels=None, regression=True + ): self.cat_features_list = cat_features_list # Categorical features tensors self.num_features_list = num_features_list # Numerical features tensors self.regression = regression @@ -54,7 +56,9 @@ def __getitem__(self, idx): tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical features) and a single label (if available). """ - cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list] + cat_features = [ + feature_tensor[idx] for feature_tensor in self.cat_features_list + ] num_features = [ torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32) for feature_tensor in self.num_features_list @@ -68,7 +72,6 @@ def __getitem__(self, idx): label = label.clone().detach().to(torch.float32) else: label = label.clone().detach().to(torch.long) - return cat_features, num_features, label + return num_features, cat_features, label else: - return cat_features, num_features # No label in prediction mode - + return num_features, cat_features # No label in prediction mode From 40fef3391378a3d68b99c3b51b60b293488fe0f7 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Sat, 18 Jan 2025 12:42:10 +0100 Subject: [PATCH 14/18] include encoding function in sklearn base classes --- mambular/models/sklearn_base_classifier.py | 53 +++++++++++++++++++++- mambular/models/sklearn_base_lss.py | 53 +++++++++++++++++++++- mambular/models/sklearn_base_regressor.py | 53 ++++++++++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index c1c264b..f4f8699 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -18,6 +18,8 @@ get_search_space, round_to_nearest_16, ) +from tqdm import tqdm +from torch.utils.data import DataLoader class SklearnBaseClassifier(BaseEstimator): @@ -176,8 +178,12 @@ def build_model( Learning rate for the optimizer. lr_patience : int, default=10 Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 + lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. weight_decay : float, default=0.025 Weight decay (L2 penalty) coefficient. dataloader_kwargs: dict, default={} @@ -336,6 +342,10 @@ def fit( Weight decay (L2 penalty) coefficient. checkpoint_path : str, default="model_checkpoints" Path where the checkpoints are being saved. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. dataloader_kwargs: dict, default={} The kwargs for the pytorch dataloader class. rebuild: bool, default=True @@ -578,6 +588,47 @@ def score(self, X, y, metric=(log_loss, True)): predictions = self.predict(X) return metric_func(y, predictions) + def encode(self, X, batch_size=64): + """ + Encodes input data using the trained model's embedding layer. + + Parameters + ---------- + X : array-like or DataFrame + Input data to be encoded. + batch_size : int, optional, default=64 + Batch size for encoding. + + Returns + ------- + torch.Tensor + Encoded representations of the input data. + + Raises + ------ + ValueError + If the model or data module is not fitted. + """ + # Ensure model and data module are initialized + if self.task_model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + encoded_dataset = self.data_module.preprocess_new_data(X) + + data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False) + + # Process data in batches + encoded_outputs = [] + for num_features, cat_features in tqdm(data_loader): + embeddings = self.task_model.base_model.encode( + num_features, cat_features + ) # Call your encode function + encoded_outputs.append(embeddings) + + # Concatenate all encoded outputs + encoded_outputs = torch.cat(encoded_outputs, dim=0) + + return encoded_outputs + def optimize_hparams( self, X, diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index f01cf6f..6242cc1 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -39,6 +39,8 @@ Quantile, StudentTDistribution, ) +from tqdm import tqdm +from torch.utils.data import DataLoader class SklearnBaseLSS(BaseEstimator): @@ -198,8 +200,12 @@ def build_model( Learning rate for the optimizer. lr_patience : int, default=10 Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 + lr_factor : float, default=0.1 Factor by which the learning rate will be reduced. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. weight_decay : float, default=0.025 Weight decay (L2 penalty) coefficient. dataloader_kwargs: dict, default={} @@ -361,6 +367,10 @@ def fit( Weight decay (L2 penalty) coefficient. distributional_kwargs : dict, default=None any arguments taht are specific for a certain distribution. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. checkpoint_path : str, default="model_checkpoints" Path where the checkpoints are being saved. dataloader_kwargs: dict, default={} @@ -596,6 +606,47 @@ def score(self, X, y, metric="NLL"): score = self.task_model.family.evaluate_nll(y, predictions) # type: ignore return score + def encode(self, X, batch_size=64): + """ + Encodes input data using the trained model's embedding layer. + + Parameters + ---------- + X : array-like or DataFrame + Input data to be encoded. + batch_size : int, optional, default=64 + Batch size for encoding. + + Returns + ------- + torch.Tensor + Encoded representations of the input data. + + Raises + ------ + ValueError + If the model or data module is not fitted. + """ + # Ensure model and data module are initialized + if self.task_model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + encoded_dataset = self.data_module.preprocess_new_data(X) + + data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False) + + # Process data in batches + encoded_outputs = [] + for num_features, cat_features in tqdm(data_loader): + embeddings = self.task_model.base_model.encode( + num_features, cat_features + ) # Call your encode function + encoded_outputs.append(embeddings) + + # Concatenate all encoded outputs + encoded_outputs = torch.cat(encoded_outputs, dim=0) + + return encoded_outputs + def optimize_hparams( self, X, diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index 4964b5c..e17c2c5 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -16,6 +16,8 @@ get_search_space, round_to_nearest_16, ) +from torch.utils.data import DataLoader +from tqdm import tqdm class SklearnBaseRegressor(BaseEstimator): @@ -178,6 +180,10 @@ def build_model( Factor by which the learning rate will be reduced. weight_decay : float, default=0.025 Weight decay (L2 penalty) coefficient. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. dataloader_kwargs: dict, default={} The kwargs for the pytorch dataloader class. @@ -333,6 +339,12 @@ def fit( Path where the checkpoints are being saved. dataloader_kwargs: dict, default={} The kwargs for the pytorch dataloader class. + train_metrics : dict, default=None + torch.metrics dict to be logged during training. + val_metrics : dict, default=None + torch.metrics dict to be logged during validation. + rebuild: bool, default=True + Whether to rebuild the model when it already was built. **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. @@ -492,6 +504,47 @@ def score(self, X, y, metric=mean_squared_error): predictions = self.predict(X) return metric(y, predictions) + def encode(self, X, batch_size=64): + """ + Encodes input data using the trained model's embedding layer. + + Parameters + ---------- + X : array-like or DataFrame + Input data to be encoded. + batch_size : int, optional, default=64 + Batch size for encoding. + + Returns + ------- + torch.Tensor + Encoded representations of the input data. + + Raises + ------ + ValueError + If the model or data module is not fitted. + """ + # Ensure model and data module are initialized + if self.task_model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + encoded_dataset = self.data_module.preprocess_new_data(X) + + data_loader = DataLoader(encoded_dataset, batch_size=batch_size, shuffle=False) + + # Process data in batches + encoded_outputs = [] + for num_features, cat_features in tqdm(data_loader): + embeddings = self.task_model.base_model.encode( + num_features, cat_features + ) # Call your encode function + encoded_outputs.append(embeddings) + + # Concatenate all encoded outputs + encoded_outputs = torch.cat(encoded_outputs, dim=0) + + return encoded_outputs + def optimize_hparams( self, X, From 0708f3f56f310bfa208862df8e35a3ce1eab25fc Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Sun, 19 Jan 2025 23:02:17 +0100 Subject: [PATCH 15/18] fix: sentence-transformers included --- poetry.lock | 573 ++++++++++++++++++++++++++++++++++++++----------- pyproject.toml | 28 +-- 2 files changed, 460 insertions(+), 141 deletions(-) diff --git a/poetry.lock b/poetry.lock index c0682a8..5583aa0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "accelerate" -version = "1.2.1" +version = "1.3.0" description = "Accelerate" optional = false python-versions = ">=3.9.0" files = [ - {file = "accelerate-1.2.1-py3-none-any.whl", hash = "sha256:be1cbb958cf837e7cdfbde46b812964b1b8ae94c9c7d94d921540beafcee8ddf"}, - {file = "accelerate-1.2.1.tar.gz", hash = "sha256:03e161fc69d495daf2b9b5c8d5b43d06e2145520c04727b5bda56d49f1a43ab5"}, + {file = "accelerate-1.3.0-py3-none-any.whl", hash = "sha256:5788d9e6a7a9f80fed665cf09681c4dddd9dc056bea656db4140ffc285ce423e"}, + {file = "accelerate-1.3.0.tar.gz", hash = "sha256:518631c0adb80bd3d42fb29e7e2dc2256bcd7c786b0ba9119bbaa08611b36d9c"}, ] [package.dependencies] @@ -18,7 +18,7 @@ packaging = ">=20.0" psutil = "*" pyyaml = "*" safetensors = ">=0.4.3" -torch = ">=1.10.0" +torch = ">=2.0.0" [package.extras] deepspeed = ["deepspeed"] @@ -610,13 +610,13 @@ tqdm = ["tqdm"] [[package]] name = "huggingface-hub" -version = "0.27.0" +version = "0.27.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.27.0-py3-none-any.whl", hash = "sha256:8f2e834517f1f1ddf1ecc716f91b120d7333011b7485f665a9a412eacb1a2a81"}, - {file = "huggingface_hub-0.27.0.tar.gz", hash = "sha256:902cce1a1be5739f5589e560198a65a8edcfd3b830b1666f36e4b961f0454fac"}, + {file = "huggingface_hub-0.27.1-py3-none-any.whl", hash = "sha256:1c5155ca7d60b60c2e2fc38cbb3ffb7f7c3adf48f824015b219af9061771daec"}, + {file = "huggingface_hub-0.27.1.tar.gz", hash = "sha256:c004463ca870283909d715d20f066ebd6968c2207dae9393fdffb3c1d4d8f98b"}, ] [package.dependencies] @@ -644,13 +644,13 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t [[package]] name = "identify" -version = "2.6.4" +version = "2.6.5" description = "File identification library for Python" optional = false python-versions = ">=3.9" files = [ - {file = "identify-2.6.4-py2.py3-none-any.whl", hash = "sha256:993b0f01b97e0568c179bb9196391ff391bfb88a99099dbf5ce392b68f42d0af"}, - {file = "identify-2.6.4.tar.gz", hash = "sha256:285a7d27e397652e8cafe537a6cc97dd470a970f48fb2e9d979aa38eae5513ac"}, + {file = "identify-2.6.5-py2.py3-none-any.whl", hash = "sha256:14181a47091eb75b337af4c23078c9d09225cd4c48929f521f3bf16b09d02566"}, + {file = "identify-2.6.5.tar.gz", hash = "sha256:c10b33f250e5bba374fae86fb57f3adcebf1161bce7cdf92031915fd480c13bc"}, ] [package.extras] @@ -1290,6 +1290,94 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pillow" +version = "11.1.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8"}, + {file = "pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482"}, + {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e"}, + {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269"}, + {file = "pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49"}, + {file = "pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a"}, + {file = "pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65"}, + {file = "pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457"}, + {file = "pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1"}, + {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2"}, + {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96"}, + {file = "pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f"}, + {file = "pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761"}, + {file = "pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71"}, + {file = "pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a"}, + {file = "pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f"}, + {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91"}, + {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c"}, + {file = "pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6"}, + {file = "pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf"}, + {file = "pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5"}, + {file = "pillow-11.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ae98e14432d458fc3de11a77ccb3ae65ddce70f730e7c76140653048c71bfcbc"}, + {file = "pillow-11.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cc1331b6d5a6e144aeb5e626f4375f5b7ae9934ba620c0ac6b3e43d5e683a0f0"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:758e9d4ef15d3560214cddbc97b8ef3ef86ce04d62ddac17ad39ba87e89bd3b1"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b523466b1a31d0dcef7c5be1f20b942919b62fd6e9a9be199d035509cbefc0ec"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:9044b5e4f7083f209c4e35aa5dd54b1dd5b112b108648f5c902ad586d4f945c5"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3764d53e09cdedd91bee65c2527815d315c6b90d7b8b79759cc48d7bf5d4f114"}, + {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:31eba6bbdd27dde97b0174ddf0297d7a9c3a507a8a1480e1e60ef914fe23d352"}, + {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b5d658fbd9f0d6eea113aea286b21d3cd4d3fd978157cbf2447a6035916506d3"}, + {file = "pillow-11.1.0-cp313-cp313-win32.whl", hash = "sha256:f86d3a7a9af5d826744fabf4afd15b9dfef44fe69a98541f666f66fbb8d3fef9"}, + {file = "pillow-11.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:593c5fd6be85da83656b93ffcccc2312d2d149d251e98588b14fbc288fd8909c"}, + {file = "pillow-11.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:11633d58b6ee5733bde153a8dafd25e505ea3d32e261accd388827ee987baf65"}, + {file = "pillow-11.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:70ca5ef3b3b1c4a0812b5c63c57c23b63e53bc38e758b37a951e5bc466449861"}, + {file = "pillow-11.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8000376f139d4d38d6851eb149b321a52bb8893a88dae8ee7d95840431977081"}, + {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ee85f0696a17dd28fbcfceb59f9510aa71934b483d1f5601d1030c3c8304f3c"}, + {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:dd0e081319328928531df7a0e63621caf67652c8464303fd102141b785ef9547"}, + {file = "pillow-11.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e63e4e5081de46517099dc30abe418122f54531a6ae2ebc8680bcd7096860eab"}, + {file = "pillow-11.1.0-cp313-cp313t-win32.whl", hash = "sha256:dda60aa465b861324e65a78c9f5cf0f4bc713e4309f83bc387be158b077963d9"}, + {file = "pillow-11.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ad5db5781c774ab9a9b2c4302bbf0c1014960a0a7be63278d13ae6fdf88126fe"}, + {file = "pillow-11.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:67cd427c68926108778a9005f2a04adbd5e67c442ed21d95389fe1d595458756"}, + {file = "pillow-11.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:bf902d7413c82a1bfa08b06a070876132a5ae6b2388e2712aab3a7cbc02205c6"}, + {file = "pillow-11.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c1eec9d950b6fe688edee07138993e54ee4ae634c51443cfb7c1e7613322718e"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e275ee4cb11c262bd108ab2081f750db2a1c0b8c12c1897f27b160c8bd57bbc"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4db853948ce4e718f2fc775b75c37ba2efb6aaea41a1a5fc57f0af59eee774b2"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:ab8a209b8485d3db694fa97a896d96dd6533d63c22829043fd9de627060beade"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:54251ef02a2309b5eec99d151ebf5c9904b77976c8abdcbce7891ed22df53884"}, + {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5bb94705aea800051a743aa4874bb1397d4695fb0583ba5e425ee0328757f196"}, + {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:89dbdb3e6e9594d512780a5a1c42801879628b38e3efc7038094430844e271d8"}, + {file = "pillow-11.1.0-cp39-cp39-win32.whl", hash = "sha256:e5449ca63da169a2e6068dd0e2fcc8d91f9558aba89ff6d02121ca8ab11e79e5"}, + {file = "pillow-11.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:3362c6ca227e65c54bf71a5f88b3d4565ff1bcbc63ae72c34b07bbb1cc59a43f"}, + {file = "pillow-11.1.0-cp39-cp39-win_arm64.whl", hash = "sha256:b20be51b37a75cc54c2c55def3fa2c65bb94ba859dde241cd0a4fd302de5ae0a"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0"}, + {file = "pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + [[package]] name = "platformdirs" version = "4.3.6" @@ -1649,6 +1737,109 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "regex" +version = "2024.11.6" +description = "Alternative regular expression module, to replace re." +optional = false +python-versions = ">=3.8" +files = [ + {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"}, + {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"}, + {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"}, + {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"}, + {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"}, + {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"}, + {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"}, + {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"}, + {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"}, + {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"}, + {file = "regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"}, + {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"}, + {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"}, + {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"}, + {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"}, + {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"}, +] + [[package]] name = "requests" version = "2.32.3" @@ -1672,53 +1863,53 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "ruff" -version = "0.8.5" +version = "0.9.2" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.8.5-py3-none-linux_armv6l.whl", hash = "sha256:5ad11a5e3868a73ca1fa4727fe7e33735ea78b416313f4368c504dbeb69c0f88"}, - {file = "ruff-0.8.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f69ab37771ea7e0715fead8624ec42996d101269a96e31f4d31be6fc33aa19b7"}, - {file = "ruff-0.8.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b5462d7804558ccff9c08fe8cbf6c14b7efe67404316696a2dde48297b1925bb"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d56de7220a35607f9fe59f8a6d018e14504f7b71d784d980835e20fc0611cd50"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9d99cf80b0429cbebf31cbbf6f24f05a29706f0437c40413d950e67e2d4faca4"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b75ac29715ac60d554a049dbb0ef3b55259076181c3369d79466cb130eb5afd"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c9d526a62c9eda211b38463528768fd0ada25dad524cb33c0e99fcff1c67b5dc"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:587c5e95007612c26509f30acc506c874dab4c4abbacd0357400bd1aa799931b"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:622b82bf3429ff0e346835ec213aec0a04d9730480cbffbb6ad9372014e31bbd"}, - {file = "ruff-0.8.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f99be814d77a5dac8a8957104bdd8c359e85c86b0ee0e38dca447cb1095f70fb"}, - {file = "ruff-0.8.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c01c048f9c3385e0fd7822ad0fd519afb282af9cf1778f3580e540629df89725"}, - {file = "ruff-0.8.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7512e8cb038db7f5db6aae0e24735ff9ea03bb0ed6ae2ce534e9baa23c1dc9ea"}, - {file = "ruff-0.8.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:762f113232acd5b768d6b875d16aad6b00082add40ec91c927f0673a8ec4ede8"}, - {file = "ruff-0.8.5-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:03a90200c5dfff49e4c967b405f27fdfa81594cbb7c5ff5609e42d7fe9680da5"}, - {file = "ruff-0.8.5-py3-none-win32.whl", hash = "sha256:8710ffd57bdaa6690cbf6ecff19884b8629ec2a2a2a2f783aa94b1cc795139ed"}, - {file = "ruff-0.8.5-py3-none-win_amd64.whl", hash = "sha256:4020d8bf8d3a32325c77af452a9976a9ad6455773bcb94991cf15bd66b347e47"}, - {file = "ruff-0.8.5-py3-none-win_arm64.whl", hash = "sha256:134ae019ef13e1b060ab7136e7828a6d83ea727ba123381307eb37c6bd5e01cb"}, - {file = "ruff-0.8.5.tar.gz", hash = "sha256:1098d36f69831f7ff2a1da3e6407d5fbd6dfa2559e4f74ff2d260c5588900317"}, + {file = "ruff-0.9.2-py3-none-linux_armv6l.whl", hash = "sha256:80605a039ba1454d002b32139e4970becf84b5fee3a3c3bf1c2af6f61a784347"}, + {file = "ruff-0.9.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b9aab82bb20afd5f596527045c01e6ae25a718ff1784cb92947bff1f83068b00"}, + {file = "ruff-0.9.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbd337bac1cfa96be615f6efcd4bc4d077edbc127ef30e2b8ba2a27e18c054d4"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b35259b0cbf8daa22a498018e300b9bb0174c2bbb7bcba593935158a78054d"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8b6a9701d1e371bf41dca22015c3f89769da7576884d2add7317ec1ec8cb9c3c"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9cc53e68b3c5ae41e8faf83a3b89f4a5d7b2cb666dff4b366bb86ed2a85b481f"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:8efd9da7a1ee314b910da155ca7e8953094a7c10d0c0a39bfde3fcfd2a015684"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3292c5a22ea9a5f9a185e2d131dc7f98f8534a32fb6d2ee7b9944569239c648d"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a605fdcf6e8b2d39f9436d343d1f0ff70c365a1e681546de0104bef81ce88df"}, + {file = "ruff-0.9.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c547f7f256aa366834829a08375c297fa63386cbe5f1459efaf174086b564247"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d18bba3d3353ed916e882521bc3e0af403949dbada344c20c16ea78f47af965e"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b338edc4610142355ccf6b87bd356729b62bf1bc152a2fad5b0c7dc04af77bfe"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:492a5e44ad9b22a0ea98cf72e40305cbdaf27fac0d927f8bc9e1df316dcc96eb"}, + {file = "ruff-0.9.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:af1e9e9fe7b1f767264d26b1075ac4ad831c7db976911fa362d09b2d0356426a"}, + {file = "ruff-0.9.2-py3-none-win32.whl", hash = "sha256:71cbe22e178c5da20e1514e1e01029c73dc09288a8028a5d3446e6bba87a5145"}, + {file = "ruff-0.9.2-py3-none-win_amd64.whl", hash = "sha256:c5e1d6abc798419cf46eed03f54f2e0c3adb1ad4b801119dedf23fcaf69b55b5"}, + {file = "ruff-0.9.2-py3-none-win_arm64.whl", hash = "sha256:a1b63fa24149918f8b37cef2ee6fff81f24f0d74b6f0bdc37bc3e1f2143e41c6"}, + {file = "ruff-0.9.2.tar.gz", hash = "sha256:b5eceb334d55fae5f316f783437392642ae18e16dcf4f1858d55d3c2a0f8f5d0"}, ] [[package]] name = "safetensors" -version = "0.5.0" +version = "0.5.2" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.5.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c683b9b485bee43422ba2855f72777c37647190281e03da4c8d2a69fa5336558"}, - {file = "safetensors-0.5.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6106aa835deb7263f7014f74c05842ab828d6c11d789f2e7e98f26b1a305e72d"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1349611f74f55c5ee1c1c144c536a2743c38f7d8bf60b9fc8267e0efc0591a2"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:56d936028ac799e18644b08a91fd98b4b62ae3dcd0440b1cfcb56535785589f1"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2f26afada2233576ffea6b80042c2c0a8105c164254af56168ec14299ad3122"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:20067e7a5e63f0cbc88457b2a1161e70ff73af4cc3a24bce90309430cd6f6e7e"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649d6a4aa34d5174ae87289068ccc2fec2a1a998ecf83425aa5a42c3eff69bcf"}, - {file = "safetensors-0.5.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:debff88f41d569a3e93a955469f83864e432af35bb34b16f65a9ddf378daa3ae"}, - {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:bdf6a3e366ea8ba1a0538db6099229e95811194432c684ea28ea7ae28763b8dc"}, - {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:0371afd84c200a80eb7103bf715108b0c3846132fb82453ae018609a15551580"}, - {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5ec7fc8c3d2f32ebf1c7011bc886b362e53ee0a1ec6d828c39d531fed8b325d6"}, - {file = "safetensors-0.5.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:53715e4ea0ef23c08f004baae0f609a7773de7d4148727760417c6760cfd6b76"}, - {file = "safetensors-0.5.0-cp38-abi3-win32.whl", hash = "sha256:b85565bc2f0456961a788d2f11d9d892eec46603db0e4923aa9512c2355aa727"}, - {file = "safetensors-0.5.0-cp38-abi3-win_amd64.whl", hash = "sha256:f451941f8aa11e7be5c3fa450e264609a2b1e65fa38ae590a74e55a94d646b76"}, - {file = "safetensors-0.5.0.tar.gz", hash = "sha256:c47b34c549fa1e0c655c4644da31332c61332c732c47c8dd9399347e9aac69d1"}, + {file = "safetensors-0.5.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:45b6092997ceb8aa3801693781a71a99909ab9cc776fbc3fa9322d29b1d3bef2"}, + {file = "safetensors-0.5.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6d0d6a8ee2215a440e1296b843edf44fd377b055ba350eaba74655a2fe2c4bae"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86016d40bcaa3bcc9a56cd74d97e654b5f4f4abe42b038c71e4f00a089c4526c"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:990833f70a5f9c7d3fc82c94507f03179930ff7d00941c287f73b6fcbf67f19e"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3dfa7c2f3fe55db34eba90c29df94bcdac4821043fc391cb5d082d9922013869"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46ff2116150ae70a4e9c490d2ab6b6e1b1b93f25e520e540abe1b81b48560c3a"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ab696dfdc060caffb61dbe4066b86419107a24c804a4e373ba59be699ebd8d5"}, + {file = "safetensors-0.5.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03c937100f38c9ff4c1507abea9928a6a9b02c9c1c9c3609ed4fb2bf413d4975"}, + {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a00e737948791b94dad83cf0eafc09a02c4d8c2171a239e8c8572fe04e25960e"}, + {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:d3a06fae62418ec8e5c635b61a8086032c9e281f16c63c3af46a6efbab33156f"}, + {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:1506e4c2eda1431099cebe9abf6c76853e95d0b7a95addceaa74c6019c65d8cf"}, + {file = "safetensors-0.5.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5c5b5d9da594f638a259fca766046f44c97244cc7ab8bef161b3e80d04becc76"}, + {file = "safetensors-0.5.2-cp38-abi3-win32.whl", hash = "sha256:fe55c039d97090d1f85277d402954dd6ad27f63034fa81985a9cc59655ac3ee2"}, + {file = "safetensors-0.5.2-cp38-abi3-win_amd64.whl", hash = "sha256:78abdddd03a406646107f973c7843276e7b64e5e32623529dc17f3d94a20f589"}, + {file = "safetensors-0.5.2.tar.gz", hash = "sha256:cb4a8d98ba12fa016f4241932b1fc5e702e5143f5374bba0bbcf7ddc1c4cf2b8"}, ] [package.extras] @@ -1736,41 +1927,41 @@ torch = ["safetensors[numpy]", "torch (>=1.10)"] [[package]] name = "scikit-learn" -version = "1.6.0" +version = "1.6.1" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.9" files = [ - {file = "scikit_learn-1.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:366fb3fa47dce90afed3d6106183f4978d6f24cfd595c2373424171b915ee718"}, - {file = "scikit_learn-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:59cd96a8d9f8dfd546f5d6e9787e1b989e981388d7803abbc9efdcde61e47460"}, - {file = "scikit_learn-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efa7a579606c73a0b3d210e33ea410ea9e1af7933fe324cb7e6fbafae4ea5948"}, - {file = "scikit_learn-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a46d3ca0f11a540b8eaddaf5e38172d8cd65a86cb3e3632161ec96c0cffb774c"}, - {file = "scikit_learn-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:5be4577769c5dde6e1b53de8e6520f9b664ab5861dd57acee47ad119fd7405d6"}, - {file = "scikit_learn-1.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1f50b4f24cf12a81c3c09958ae3b864d7534934ca66ded3822de4996d25d7285"}, - {file = "scikit_learn-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:eb9ae21f387826da14b0b9cb1034f5048ddb9182da429c689f5f4a87dc96930b"}, - {file = "scikit_learn-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0baa91eeb8c32632628874a5c91885eaedd23b71504d24227925080da075837a"}, - {file = "scikit_learn-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c716d13ba0a2f8762d96ff78d3e0cde90bc9c9b5c13d6ab6bb9b2d6ca6705fd"}, - {file = "scikit_learn-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:9aafd94bafc841b626681e626be27bf1233d5a0f20f0a6fdb4bee1a1963c6643"}, - {file = "scikit_learn-1.6.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:04a5ba45c12a5ff81518aa4f1604e826a45d20e53da47b15871526cda4ff5174"}, - {file = "scikit_learn-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:21fadfc2ad7a1ce8bd1d90f23d17875b84ec765eecbbfc924ff11fb73db582ce"}, - {file = "scikit_learn-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30f34bb5fde90e020653bb84dcb38b6c83f90c70680dbd8c38bd9becbad7a127"}, - {file = "scikit_learn-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1dad624cffe3062276a0881d4e441bc9e3b19d02d17757cd6ae79a9d192a0027"}, - {file = "scikit_learn-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:2fce7950a3fad85e0a61dc403df0f9345b53432ac0e47c50da210d22c60b6d85"}, - {file = "scikit_learn-1.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e5453b2e87ef8accedc5a8a4e6709f887ca01896cd7cc8a174fe39bd4bb00aef"}, - {file = "scikit_learn-1.6.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5fe11794236fb83bead2af26a87ced5d26e3370b8487430818b915dafab1724e"}, - {file = "scikit_learn-1.6.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:61fe3dcec0d82ae280877a818ab652f4988371e32dd5451e75251bece79668b1"}, - {file = "scikit_learn-1.6.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b44e3a51e181933bdf9a4953cc69c6025b40d2b49e238233f149b98849beb4bf"}, - {file = "scikit_learn-1.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:a17860a562bac54384454d40b3f6155200c1c737c9399e6a97962c63fce503ac"}, - {file = "scikit_learn-1.6.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:98717d3c152f6842d36a70f21e1468fb2f1a2f8f2624d9a3f382211798516426"}, - {file = "scikit_learn-1.6.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:34e20bfac8ff0ebe0ff20fb16a4d6df5dc4cc9ce383e00c2ab67a526a3c67b18"}, - {file = "scikit_learn-1.6.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eba06d75815406091419e06dd650b91ebd1c5f836392a0d833ff36447c2b1bfa"}, - {file = "scikit_learn-1.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b6916d1cec1ff163c7d281e699d7a6a709da2f2c5ec7b10547e08cc788ddd3ae"}, - {file = "scikit_learn-1.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:66b1cf721a9f07f518eb545098226796c399c64abdcbf91c2b95d625068363da"}, - {file = "scikit_learn-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b35b60cf4cd6564b636e4a40516b3c61a4fa7a8b1f7a3ce80c38ebe04750bc3"}, - {file = "scikit_learn-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a73b1c2038c93bc7f4bf21f6c9828d5116c5d2268f7a20cfbbd41d3074d52083"}, - {file = "scikit_learn-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c3fa7d3dd5a0ec2d0baba0d644916fa2ab180ee37850c5d536245df916946bd"}, - {file = "scikit_learn-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:df778486a32518cda33818b7e3ce48c78cef1d5f640a6bc9d97c6d2e71449a51"}, - {file = "scikit_learn-1.6.0.tar.gz", hash = "sha256:9d58481f9f7499dff4196927aedd4285a0baec8caa3790efbe205f13de37dd6e"}, + {file = "scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e"}, + {file = "scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36"}, + {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5"}, + {file = "scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b"}, + {file = "scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002"}, + {file = "scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33"}, + {file = "scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d"}, + {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2"}, + {file = "scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8"}, + {file = "scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415"}, + {file = "scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b"}, + {file = "scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2"}, + {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f"}, + {file = "scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86"}, + {file = "scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52"}, + {file = "scikit_learn-1.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2ffa1e9e25b3d93990e74a4be2c2fc61ee5af85811562f1288d5d055880c4322"}, + {file = "scikit_learn-1.6.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:dc5cf3d68c5a20ad6d571584c0750ec641cc46aeef1c1507be51300e6003a7e1"}, + {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c06beb2e839ecc641366000ca84f3cf6fa9faa1777e29cf0c04be6e4d096a348"}, + {file = "scikit_learn-1.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8ca8cb270fee8f1f76fa9bfd5c3507d60c6438bbee5687f81042e2bb98e5a97"}, + {file = "scikit_learn-1.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:7a1c43c8ec9fde528d664d947dc4c0789be4077a3647f232869f41d9bf50e0fb"}, + {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a17c1dea1d56dcda2fac315712f3651a1fea86565b64b48fa1bc090249cbf236"}, + {file = "scikit_learn-1.6.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6a7aa5f9908f0f28f4edaa6963c0a6183f1911e63a69aa03782f0d924c830a35"}, + {file = "scikit_learn-1.6.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0650e730afb87402baa88afbf31c07b84c98272622aaba002559b614600ca691"}, + {file = "scikit_learn-1.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:3f59fe08dc03ea158605170eb52b22a105f238a5d512c4470ddeca71feae8e5f"}, + {file = "scikit_learn-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6849dd3234e87f55dce1db34c89a810b489ead832aaf4d4550b7ea85628be6c1"}, + {file = "scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e7be3fa5d2eb9be7d77c3734ff1d599151bb523674be9b834e8da6abe132f44e"}, + {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44a17798172df1d3c1065e8fcf9019183f06c87609b49a124ebdf57ae6cb0107"}, + {file = "scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b7a3b86e411e4bce21186e1c180d792f3d99223dcfa3b4f597ecc92fa1a422"}, + {file = "scikit_learn-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:7a73d457070e3318e32bdb3aa79a8d990474f19035464dfd8bede2883ab5dc3b"}, + {file = "scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e"}, ] [package.dependencies] @@ -1814,51 +2005,51 @@ plots = ["matplotlib (>=2.0.0)"] [[package]] name = "scipy" -version = "1.15.0" +version = "1.15.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.15.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca"}, - {file = "scipy-1.15.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d"}, - {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c"}, - {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d"}, - {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8"}, - {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4"}, - {file = "scipy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37"}, - {file = "scipy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731"}, - {file = "scipy-1.15.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020"}, - {file = "scipy-1.15.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443"}, - {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136"}, - {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e"}, - {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f"}, - {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0"}, - {file = "scipy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b"}, - {file = "scipy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d"}, - {file = "scipy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6"}, - {file = "scipy-1.15.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913"}, - {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192"}, - {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054"}, - {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e"}, - {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1"}, - {file = "scipy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863"}, - {file = "scipy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479"}, - {file = "scipy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422"}, - {file = "scipy-1.15.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8"}, - {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b"}, - {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0"}, - {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111"}, - {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4"}, - {file = "scipy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c"}, - {file = "scipy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2"}, - {file = "scipy-1.15.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff"}, - {file = "scipy-1.15.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34"}, - {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52"}, - {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6"}, - {file = "scipy-1.15.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5"}, - {file = "scipy-1.15.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df"}, - {file = "scipy-1.15.0-cp313-cp313t-win_amd64.whl", hash = "sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2"}, - {file = "scipy-1.15.0.tar.gz", hash = "sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac"}, + {file = "scipy-1.15.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:c64ded12dcab08afff9e805a67ff4480f5e69993310e093434b10e85dc9d43e1"}, + {file = "scipy-1.15.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5b190b935e7db569960b48840e5bef71dc513314cc4e79a1b7d14664f57fd4ff"}, + {file = "scipy-1.15.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:4b17d4220df99bacb63065c76b0d1126d82bbf00167d1730019d2a30d6ae01ea"}, + {file = "scipy-1.15.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:63b9b6cd0333d0eb1a49de6f834e8aeaefe438df8f6372352084535ad095219e"}, + {file = "scipy-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f151e9fb60fbf8e52426132f473221a49362091ce7a5e72f8aa41f8e0da4f25"}, + {file = "scipy-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e10b1dd56ce92fba3e786007322542361984f8463c6d37f6f25935a5a6ef52"}, + {file = "scipy-1.15.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5dff14e75cdbcf07cdaa1c7707db6017d130f0af9ac41f6ce443a93318d6c6e0"}, + {file = "scipy-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:f82fcf4e5b377f819542fbc8541f7b5fbcf1c0017d0df0bc22c781bf60abc4d8"}, + {file = "scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2"}, + {file = "scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0"}, + {file = "scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf"}, + {file = "scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac"}, + {file = "scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df"}, + {file = "scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7"}, + {file = "scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a"}, + {file = "scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b"}, + {file = "scipy-1.15.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c09aa9d90f3500ea4c9b393ee96f96b0ccb27f2f350d09a47f533293c78ea776"}, + {file = "scipy-1.15.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0ac102ce99934b162914b1e4a6b94ca7da0f4058b6d6fd65b0cef330c0f3346f"}, + {file = "scipy-1.15.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:09c52320c42d7f5c7748b69e9f0389266fd4f82cf34c38485c14ee976cb8cb04"}, + {file = "scipy-1.15.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:cdde8414154054763b42b74fe8ce89d7f3d17a7ac5dd77204f0e142cdc9239e9"}, + {file = "scipy-1.15.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c9d8fc81d6a3b6844235e6fd175ee1d4c060163905a2becce8e74cb0d7554ce"}, + {file = "scipy-1.15.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0fb57b30f0017d4afa5fe5f5b150b8f807618819287c21cbe51130de7ccdaed2"}, + {file = "scipy-1.15.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:491d57fe89927fa1aafbe260f4cfa5ffa20ab9f1435025045a5315006a91b8f5"}, + {file = "scipy-1.15.1-cp312-cp312-win_amd64.whl", hash = "sha256:900f3fa3db87257510f011c292a5779eb627043dd89731b9c461cd16ef76ab3d"}, + {file = "scipy-1.15.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:100193bb72fbff37dbd0bf14322314fc7cbe08b7ff3137f11a34d06dc0ee6b85"}, + {file = "scipy-1.15.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:2114a08daec64980e4b4cbdf5bee90935af66d750146b1d2feb0d3ac30613692"}, + {file = "scipy-1.15.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:6b3e71893c6687fc5e29208d518900c24ea372a862854c9888368c0b267387ab"}, + {file = "scipy-1.15.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:837299eec3d19b7e042923448d17d95a86e43941104d33f00da7e31a0f715d3c"}, + {file = "scipy-1.15.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82add84e8a9fb12af5c2c1a3a3f1cb51849d27a580cb9e6bd66226195142be6e"}, + {file = "scipy-1.15.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:070d10654f0cb6abd295bc96c12656f948e623ec5f9a4eab0ddb1466c000716e"}, + {file = "scipy-1.15.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:55cc79ce4085c702ac31e49b1e69b27ef41111f22beafb9b49fea67142b696c4"}, + {file = "scipy-1.15.1-cp313-cp313-win_amd64.whl", hash = "sha256:c352c1b6d7cac452534517e022f8f7b8d139cd9f27e6fbd9f3cbd0bfd39f5bef"}, + {file = "scipy-1.15.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0458839c9f873062db69a03de9a9765ae2e694352c76a16be44f93ea45c28d2b"}, + {file = "scipy-1.15.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:af0b61c1de46d0565b4b39c6417373304c1d4f5220004058bdad3061c9fa8a95"}, + {file = "scipy-1.15.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:71ba9a76c2390eca6e359be81a3e879614af3a71dfdabb96d1d7ab33da6f2364"}, + {file = "scipy-1.15.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:14eaa373c89eaf553be73c3affb11ec6c37493b7eaaf31cf9ac5dffae700c2e0"}, + {file = "scipy-1.15.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f735bc41bd1c792c96bc426dece66c8723283695f02df61dcc4d0a707a42fc54"}, + {file = "scipy-1.15.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2722a021a7929d21168830790202a75dbb20b468a8133c74a2c0230c72626b6c"}, + {file = "scipy-1.15.1-cp313-cp313t-win_amd64.whl", hash = "sha256:bc7136626261ac1ed988dca56cfc4ab5180f75e0ee52e58f1e6aa74b5f3eacd5"}, + {file = "scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6"}, ] [package.dependencies] @@ -1869,25 +2060,52 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.16.5)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.0.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "sentence-transformers" +version = "3.3.1" +description = "State-of-the-Art Text Embeddings" +optional = false +python-versions = ">=3.9" +files = [ + {file = "sentence_transformers-3.3.1-py3-none-any.whl", hash = "sha256:abffcc79dab37b7d18d21a26d5914223dd42239cfe18cb5e111c66c54b658ae7"}, + {file = "sentence_transformers-3.3.1.tar.gz", hash = "sha256:9635dbfb11c6b01d036b9cfcee29f7716ab64cf2407ad9f403a2e607da2ac48b"}, +] + +[package.dependencies] +huggingface-hub = ">=0.20.0" +Pillow = "*" +scikit-learn = "*" +scipy = "*" +torch = ">=1.11.0" +tqdm = "*" +transformers = ">=4.41.0,<5.0.0" + +[package.extras] +dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"] +onnx = ["optimum[onnxruntime] (>=1.23.1)"] +onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"] +openvino = ["optimum-intel[openvino] (>=1.20.0)"] +train = ["accelerate (>=0.20.3)", "datasets"] + [[package]] name = "setuptools" -version = "75.6.0" +version = "75.8.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" files = [ - {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, - {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, + {file = "setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3"}, + {file = "setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6"}, ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -1928,6 +2146,38 @@ files = [ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, ] +[[package]] +name = "tokenizers" +version = "0.21.0" +description = "" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"}, + {file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"}, + {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"}, + {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"}, + {file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"}, + {file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"}, + {file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + [[package]] name = "tomli" version = "2.2.1" @@ -2070,6 +2320,75 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "transformers" +version = "4.48.0" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.9.0" +files = [ + {file = "transformers-4.48.0-py3-none-any.whl", hash = "sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"}, + {file = "transformers-4.48.0.tar.gz", hash = "sha256:03fdfcbfb8b0367fb6c9fbe9d1c9aa54dfd847618be9b52400b2811d22799cb1"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.24.0,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.21,<0.22" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] +codecarbon = ["codecarbon (>=2.8.1)"] +deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.5.1)"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] +timm = ["timm (<=1.0.11)"] +tokenizers = ["tokenizers (>=0.21,<0.22)"] +torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] +video = ["av (==9.2.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + [[package]] name = "triton" version = "3.1.0" @@ -2143,13 +2462,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.28.1" +version = "20.29.1" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.8" files = [ - {file = "virtualenv-20.28.1-py3-none-any.whl", hash = "sha256:412773c85d4dab0409b83ec36f7a6499e72eaf08c80e81e9576bca61831c71cb"}, - {file = "virtualenv-20.28.1.tar.gz", hash = "sha256:5d34ab240fdb5d21549b76f9e8ff3af28252f5499fb6d6f031adac4e5a8c5329"}, + {file = "virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779"}, + {file = "virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35"}, ] [package.dependencies] @@ -2260,4 +2579,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.10, <=3.12" -content-hash = "c7606af7fb47a2fb5e856b23ef3e06a1740544bda46470dafeb7c7a3ca794d5e" +content-hash = "a2145b5f1d55eea1ccc3ef498aa90aaa96a62ab56928a3adf33a51b2700361be" diff --git a/pyproject.toml b/pyproject.toml index 70b8acc..e2ce51b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ scikit-optimize = "^0.10.2" einops = "^0.8.0" accelerate = "^1.2.1" scipy = "^1.15.0" - +sentence-transformers = "^3.3.1" [tool.poetry.group.dev.dependencies] pytest = "^8.1" @@ -56,19 +56,6 @@ line-length = 120 target-version = "py310" exclude = ["*.ipynb", "mambular/arch_utils/mamba_utils.mamba_orginal.py"] -ignore = [ - "B006", - "F401", # Ignore unused imports - "F841", # Ignore unused local variables - "E501", # Ignore line length - "D100", # Missing module-level docstring - "D101", # Missing class-level docstring - "D102", # Missing method-level docstring - "D103", # Missing function-level docstring - "B007", - "S307", -] - [tool.ruff.lint] select = [ "A", # flake8-buildins @@ -83,6 +70,19 @@ select = [ "W", # pycodestyle - warnings ] +ignore = [ + "B006", + "F401", # Ignore unused imports + "F841", # Ignore unused local variables + "E501", # Ignore line length + "D100", # Missing module-level docstring + "D101", # Missing class-level docstring + "D102", # Missing method-level docstring + "D103", # Missing function-level docstring + "B007", + "S307", +] + [tool.ruff.lint.per-file-ignores] # allow asserts in test files (bandit) From c4df54130787628d289c3671a70fb4f0084da574 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Sun, 19 Jan 2025 23:05:52 +0100 Subject: [PATCH 16/18] fix: B904 --- mambular/preprocessing/prepro_utils.py | 34 ++++++++------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/mambular/preprocessing/prepro_utils.py b/mambular/preprocessing/prepro_utils.py index 1d5e41b..4c9e964 100644 --- a/mambular/preprocessing/prepro_utils.py +++ b/mambular/preprocessing/prepro_utils.py @@ -57,10 +57,7 @@ def fit(self, X, y=None): self: Returns the instance itself. """ # Fit should determine the mapping from original categories to sequential integers starting from 0 - self.mapping_ = [ - {category: i + 1 for i, category in enumerate(np.unique(col))} - for col in X.T - ] + self.mapping_ = [{category: i + 1 for i, category in enumerate(np.unique(col))} for col in X.T] for mapping in self.mapping_: mapping[None] = 0 # Assign 0 to unknown values return self @@ -75,12 +72,7 @@ def transform(self, X): X_transformed (ndarray of shape (n_samples, n_features)): The transformed data with integer values. """ # Transform the categories to their mapped integer values - X_transformed = np.array( - [ - [self.mapping_[col].get(value, 0) for col, value in enumerate(row)] - for row in X - ] - ) + X_transformed = np.array([[self.mapping_[col].get(value, 0) for col, value in enumerate(row)] for row in X]) return X_transformed def get_feature_names_out(self, input_features=None): @@ -122,9 +114,7 @@ def fit(self, X, y=None): Returns: self: Returns the instance itself. """ - self.max_bins_ = ( - np.max(X, axis=0).astype(int) + 1 - ) # Find the maximum bin index for each feature + self.max_bins_ = np.max(X, axis=0).astype(int) + 1 # Find the maximum bin index for each feature return self def transform(self, X): @@ -207,9 +197,7 @@ def get_feature_names_out(self, input_features=None): feature_names (array of shape (n_features,)): The original feature names. """ if input_features is None: - raise ValueError( - "input_features must be provided to generate feature names." - ) + raise ValueError("input_features must be provided to generate feature names.") return np.array(input_features) @@ -243,10 +231,10 @@ def __init__(self, model_name="paraphrase-MiniLM-L3-v2", model=None): from sentence_transformers import SentenceTransformer self.model = SentenceTransformer(model_name) - except ImportError: + except ImportError as e: raise ImportError( "sentence-transformers is not installed. Install it via `pip install sentence-transformers` or provide a preloaded model." - ) + ) from e def fit(self, X, y=None): """Fit method (not required for a transformer but included for compatibility).""" @@ -264,13 +252,11 @@ def transform(self, X): - A 2D numpy array with embeddings for each text input. """ if isinstance(X, np.ndarray): - X = ( - X.flatten().astype(str).tolist() - ) # Convert to a list of strings if passed as an array + X = X.flatten().astype(str).tolist() # Convert to a list of strings if passed as an array elif isinstance(X, list): X = [str(x) for x in X] # Ensure everything is a string - embeddings = self.model.encode( - X, convert_to_numpy=True - ) # Get sentence embeddings + if self.model is None: + raise ValueError("Model is not initialized. Ensure that the model is properly loaded.") + embeddings = self.model.encode(X, convert_to_numpy=True) # Get sentence embeddings return embeddings From 2473e5c99e3a6e7a70fcf98fff3bb2e0a7fb56a3 Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Sun, 19 Jan 2025 23:08:01 +0100 Subject: [PATCH 17/18] chore: auto formatting --- mambular/base_models/lightning_wrapper.py | 23 ++--- mambular/data_utils/dataset.py | 13 +-- mambular/models/sklearn_base_classifier.py | 101 ++++++--------------- mambular/models/sklearn_base_lss.py | 100 ++++++-------------- mambular/models/sklearn_base_regressor.py | 101 ++++++--------------- mambular/preprocessing/basis_expansion.py | 6 +- mambular/preprocessing/ple_encoding.py | 7 +- mambular/preprocessing/preprocessor.py | 84 +++++------------ 8 files changed, 117 insertions(+), 318 deletions(-) diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index afc17b8..1d8530e 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -1,4 +1,5 @@ from collections.abc import Callable + import lightning as pl import torch import torch.nn as nn @@ -144,10 +145,7 @@ def compute_loss(self, predictions, y_true): ) if getattr(self.base_model, "returns_ensemble", False): # Ensemble case - if ( - self.loss_fct.__class__.__name__ == "CrossEntropyLoss" - and predictions.dim() == 3 - ): + if self.loss_fct.__class__.__name__ == "CrossEntropyLoss" and predictions.dim() == 3: # Classification case with ensemble: predictions (N, E, k), y_true (N,) N, E, k = predictions.shape loss = 0.0 @@ -192,18 +190,14 @@ def training_step(self, batch, batch_idx): # type: ignore # Check if the model has a `penalty_forward` method if hasattr(self.base_model, "penalty_forward"): - preds, penalty = self.base_model.penalty_forward( - num_features=num_features, cat_features=cat_features - ) + preds, penalty = self.base_model.penalty_forward(num_features=num_features, cat_features=cat_features) loss = self.compute_loss(preds, labels) + penalty else: preds = self(num_features=num_features, cat_features=cat_features) loss = self.compute_loss(preds, labels) # Log the training loss - self.log( - "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True - ) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) # Log custom training metrics for metric_name, metric_fn in self.train_metrics.items(): @@ -352,13 +346,8 @@ def on_validation_epoch_end(self): # Apply pruning logic if needed if self.current_epoch >= self.pruning_epoch: - if ( - self.early_pruning_threshold is not None - and val_loss_value > self.early_pruning_threshold - ): - print( - f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}" - ) + if self.early_pruning_threshold is not None and val_loss_value > self.early_pruning_threshold: + print(f"Pruned at epoch {self.current_epoch}, val_loss {val_loss_value}") self.trainer.should_stop = True # Stop training early def epoch_val_loss_at(self, epoch): diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index 20076ea..6bb0485 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -3,11 +3,6 @@ from torch.utils.data import Dataset -import numpy as np -import torch -from torch.utils.data import Dataset - - class MambularDataset(Dataset): """Custom dataset for handling structured data with separate categorical and numerical features, tailored for both regression and classification tasks. @@ -20,9 +15,7 @@ class MambularDataset(Dataset): regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True. """ - def __init__( - self, cat_features_list, num_features_list, labels=None, regression=True - ): + def __init__(self, cat_features_list, num_features_list, labels=None, regression=True): self.cat_features_list = cat_features_list # Categorical features tensors self.num_features_list = num_features_list # Numerical features tensors self.regression = regression @@ -56,9 +49,7 @@ def __getitem__(self, idx): tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical features) and a single label (if available). """ - cat_features = [ - feature_tensor[idx] for feature_tensor in self.cat_features_list - ] + cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list] num_features = [ torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32) for feature_tensor in self.num_features_list diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index f4f8699..6317e62 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -1,4 +1,5 @@ import warnings +from collections.abc import Callable from typing import Optional import lightning as pl @@ -9,17 +10,13 @@ from sklearn.base import BaseEstimator from sklearn.metrics import accuracy_score, log_loss, mean_squared_error from skopt import gp_minimize -from collections.abc import Callable +from torch.utils.data import DataLoader +from tqdm import tqdm + from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor -from ..utils.config_mapper import ( - activation_mapper, - get_search_space, - round_to_nearest_16, -) -from tqdm import tqdm -from torch.utils.data import DataLoader +from ..utils.config_mapper import activation_mapper, get_search_space, round_to_nearest_16 class SklearnBaseClassifier(BaseEstimator): @@ -42,15 +39,11 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v - for k, v in kwargs.items() - if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names - } + preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None @@ -70,8 +63,7 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k - not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -92,10 +84,7 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = { - "prepro__" + key: value - for key, value in self.preprocessor.get_params().items() - } + preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} params.update(preprocessor_params) return params @@ -113,14 +102,8 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = { - k: v for k, v in parameters.items() if not k.startswith("prepro__") - } - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("prepro__") - } + config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} + preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} if config_params: self.config_kwargs.update(config_params) @@ -218,9 +201,7 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) + self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) num_classes = len(np.unique(np.array(y))) @@ -230,14 +211,10 @@ def build_model( config=self.config, cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, - lr_patience=( - lr_patience if lr_patience is not None else self.config.lr_patience - ), + lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), lr=lr if lr is not None else self.config.lr, lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=( - weight_decay if weight_decay is not None else self.config.weight_decay - ), + weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), train_metrics=train_metrics, val_metrics=val_metrics, optimizer_type=self.optimizer_type, @@ -268,9 +245,7 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) + raise ValueError("The model must be built before the number of parameters can be estimated") else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -442,7 +417,7 @@ def predict(self, X, device=None): logits_list = self.trainer.predict(self.task_model, self.data_module) # Concatenate predictions from all batches - logits = torch.cat(logits_list, dim=0) + logits = torch.cat(logits_list, dim=0) # type: ignore # Check if ensemble is used if getattr(self.base_model, "returns_ensemble", False): # If using ensemble @@ -619,9 +594,7 @@ def encode(self, X, batch_size=64): # Process data in batches encoded_outputs = [] for num_features, cat_features in tqdm(data_loader): - embeddings = self.task_model.base_model.encode( - num_features, cat_features - ) # Call your encode function + embeddings = self.task_model.base_model.encode(num_features, cat_features) # Call your encode function encoded_outputs.append(embeddings) # Concatenate all encoded outputs @@ -689,13 +662,9 @@ def optimize_hparams( best_val_loss = float("inf") if X_val is not None and y_val is not None: - val_loss = self.evaluate( - X_val, y_val, metrics={"Accuracy": (accuracy_score, False)} - )["Accuracy"] + val_loss = self.evaluate(X_val, y_val, metrics={"Accuracy": (accuracy_score, False)})["Accuracy"] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ - "val_loss" - ] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -721,9 +690,7 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError( - f"Unknown activation function: {param_value}" - ) + raise ValueError(f"Unknown activation function: {param_value}") else: setattr(self.config, key, param_value) @@ -732,15 +699,11 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model( - X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs - ) + self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = ( - best_epoch_val_loss * 1.5 - ) # Prune based on specific epoch loss + early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -752,9 +715,7 @@ def _objective(hyperparams): # Fit the model (limit epochs for faster optimization) try: # Wrap the risky operation (model fitting) in a try-except block - self.fit( - X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False - ) + self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False) # Evaluate validation loss if X_val is not None and y_val is not None: @@ -762,9 +723,7 @@ def _objective(hyperparams): "Mean Squared Error" ] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[ - 0 - ]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -781,21 +740,15 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print( - f"Error encountered during fit with hyperparameters {hyperparams}: {e}" - ) - return ( - best_val_loss * 100 - ) # Large value to discourage this configuration + print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") + return best_val_loss * 100 # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = ( - [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None - ) + head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 6242cc1..e1b63a7 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -1,4 +1,5 @@ import warnings +from collections.abc import Callable import lightning as pl import numpy as np @@ -9,7 +10,9 @@ from sklearn.base import BaseEstimator from sklearn.metrics import accuracy_score, mean_squared_error from skopt import gp_minimize -from collections.abc import Callable +from torch.utils.data import DataLoader +from tqdm import tqdm + from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor @@ -39,8 +42,6 @@ Quantile, StudentTDistribution, ) -from tqdm import tqdm -from torch.utils.data import DataLoader class SklearnBaseLSS(BaseEstimator): @@ -63,15 +64,11 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v - for k, v in kwargs.items() - if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names - } + preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} self.preprocessor = Preprocessor(**preprocessor_kwargs) self.task_model = None @@ -92,8 +89,7 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k - not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -114,10 +110,7 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = { - "prepro__" + key: value - for key, value in self.preprocessor.get_params().items() - } + preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} params.update(preprocessor_params) return params @@ -135,14 +128,8 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = { - k: v for k, v in parameters.items() if not k.startswith("prepro__") - } - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("prepro__") - } + config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} + preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} if config_params: self.config_kwargs.update(config_params) @@ -238,9 +225,7 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) + self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) self.task_model = TaskModel( model_class=self.base_model, # type: ignore @@ -250,13 +235,9 @@ def build_model( cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, lr=lr if lr is not None else self.config.lr, - lr_patience=( - lr_patience if lr_patience is not None else self.config.lr_patience - ), + lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=( - weight_decay if weight_decay is not None else self.config.weight_decay - ), + weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), lss=True, train_metrics=train_metrics, val_metrics=val_metrics, @@ -288,9 +269,7 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) + raise ValueError("The model must be built before the number of parameters can be estimated") else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -530,9 +509,7 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None): """ # Infer distribution family from model settings if not provided if distribution_family is None: - distribution_family = getattr( - self.task_model, "distribution_family", "normal" - ) + distribution_family = getattr(self.task_model, "distribution_family", "normal") # Setup default metrics if none are provided if metrics is None: @@ -568,10 +545,7 @@ def get_default_metrics(self, distribution_family): "normal": { "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]), "CRPS": lambda y, pred: np.mean( - [ - ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) - for i in range(len(y)) - ] + [ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) for i in range(len(y))] ), }, "poisson": {"Poisson Deviance": poisson_deviance}, @@ -637,9 +611,7 @@ def encode(self, X, batch_size=64): # Process data in batches encoded_outputs = [] for num_features, cat_features in tqdm(data_loader): - embeddings = self.task_model.base_model.encode( - num_features, cat_features - ) # Call your encode function + embeddings = self.task_model.base_model.encode(num_features, cat_features) # Call your encode function encoded_outputs.append(embeddings) # Concatenate all encoded outputs @@ -712,9 +684,7 @@ def optimize_hparams( y_val, ) else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ - "val_loss" - ] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -740,9 +710,7 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError( - f"Unknown activation function: {param_value}" - ) + raise ValueError(f"Unknown activation function: {param_value}") else: setattr(self.config, key, param_value) @@ -751,15 +719,11 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model( - X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs - ) + self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = ( - best_epoch_val_loss * 1.5 - ) # Prune based on specific epoch loss + early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -781,13 +745,11 @@ def _objective(hyperparams): # Evaluate validation loss if X_val is not None and y_val is not None: - val_loss = self.evaluate( - X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} - )["Mean Squared Error"] + val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ + "Mean Squared Error" + ] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[ - 0 - ]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -804,21 +766,15 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print( - f"Error encountered during fit with hyperparameters {hyperparams}: {e}" - ) - return ( - best_val_loss * 100 - ) # Large value to discourage this configuration + print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") + return best_val_loss * 100 # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = ( - [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None - ) + head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index e17c2c5..04f9ac3 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -1,4 +1,5 @@ import warnings +from collections.abc import Callable import lightning as pl import pandas as pd @@ -7,7 +8,9 @@ from sklearn.base import BaseEstimator from sklearn.metrics import mean_squared_error from skopt import gp_minimize -from collections.abc import Callable +from torch.utils.data import DataLoader +from tqdm import tqdm + from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor @@ -16,8 +19,6 @@ get_search_space, round_to_nearest_16, ) -from torch.utils.data import DataLoader -from tqdm import tqdm class SklearnBaseRegressor(BaseEstimator): @@ -40,15 +41,11 @@ def __init__(self, model, config, **kwargs): ] self.config_kwargs = { - k: v - for k, v in kwargs.items() - if k not in self.preprocessor_arg_names and not k.startswith("optimizer") + k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names and not k.startswith("optimizer") } self.config = config(**self.config_kwargs) - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names - } + preprocessor_kwargs = {k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names} self.preprocessor = Preprocessor(**preprocessor_kwargs) self.base_model = model @@ -68,8 +65,7 @@ def __init__(self, model, config, **kwargs): self.optimizer_kwargs = { k: v for k, v in kwargs.items() - if k - not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] + if k not in ["lr", "weight_decay", "patience", "lr_patience", "optimizer_type"] and k.startswith("optimizer_") } @@ -90,10 +86,7 @@ def get_params(self, deep=True): params.update(self.config_kwargs) if deep: - preprocessor_params = { - "prepro__" + key: value - for key, value in self.preprocessor.get_params().items() - } + preprocessor_params = {"prepro__" + key: value for key, value in self.preprocessor.get_params().items()} params.update(preprocessor_params) return params @@ -111,14 +104,8 @@ def set_params(self, **parameters): self : object Estimator instance. """ - config_params = { - k: v for k, v in parameters.items() if not k.startswith("prepro__") - } - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("prepro__") - } + config_params = {k: v for k, v in parameters.items() if not k.startswith("prepro__")} + preprocessor_params = {k.split("__")[1]: v for k, v in parameters.items() if k.startswith("prepro__")} if config_params: self.config_kwargs.update(config_params) @@ -216,9 +203,7 @@ def build_model( **dataloader_kwargs, ) - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) + self.data_module.preprocess_data(X, y, X_val, y_val, val_size=val_size, random_state=random_state) self.task_model = TaskModel( model_class=self.base_model, # type: ignore @@ -226,13 +211,9 @@ def build_model( cat_feature_info=self.data_module.cat_feature_info, num_feature_info=self.data_module.num_feature_info, lr=lr if lr is not None else self.config.lr, - lr_patience=( - lr_patience if lr_patience is not None else self.config.lr_patience - ), + lr_patience=(lr_patience if lr_patience is not None else self.config.lr_patience), lr_factor=lr_factor if lr_factor is not None else self.config.lr_factor, - weight_decay=( - weight_decay if weight_decay is not None else self.config.weight_decay - ), + weight_decay=(weight_decay if weight_decay is not None else self.config.weight_decay), train_metrics=train_metrics, val_metrics=val_metrics, optimizer_type=self.optimizer_type, @@ -263,9 +244,7 @@ def get_number_of_params(self, requires_grad=True): If the model has not been built prior to calling this method. """ if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) + raise ValueError("The model must be built before the number of parameters can be estimated") else: if requires_grad: return sum(p.numel() for p in self.task_model.parameters() if p.requires_grad) # type: ignore @@ -535,9 +514,7 @@ def encode(self, X, batch_size=64): # Process data in batches encoded_outputs = [] for num_features, cat_features in tqdm(data_loader): - embeddings = self.task_model.base_model.encode( - num_features, cat_features - ) # Call your encode function + embeddings = self.task_model.base_model.encode(num_features, cat_features) # Call your encode function encoded_outputs.append(embeddings) # Concatenate all encoded outputs @@ -605,13 +582,11 @@ def optimize_hparams( best_val_loss = float("inf") if X_val is not None and y_val is not None: - val_loss = self.evaluate( - X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} - )["Mean Squared Error"] - else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[0][ - "val_loss" + val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ + "Mean Squared Error" ] + else: + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] best_val_loss = val_loss best_epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -637,9 +612,7 @@ def _objective(hyperparams): if param_value in activation_mapper: setattr(self.config, key, activation_mapper[param_value]) else: - raise ValueError( - f"Unknown activation function: {param_value}" - ) + raise ValueError(f"Unknown activation function: {param_value}") else: setattr(self.config, key, param_value) @@ -648,15 +621,11 @@ def _objective(hyperparams): self.config.head_layer_sizes = head_layer_sizes[:head_layer_size_length] # Build the model with updated hyperparameters - self.build_model( - X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs - ) + self.build_model(X, y, X_val=X_val, y_val=y_val, lr=self.config.lr, **optimize_kwargs) # Dynamically set the early pruning threshold if prune_by_epoch: - early_pruning_threshold = ( - best_epoch_val_loss * 1.5 - ) # Prune based on specific epoch loss + early_pruning_threshold = best_epoch_val_loss * 1.5 # Prune based on specific epoch loss else: # Prune based on the best overall validation loss early_pruning_threshold = best_val_loss * 1.5 @@ -667,19 +636,15 @@ def _objective(hyperparams): try: # Wrap the risky operation (model fitting) in a try-except block - self.fit( - X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False - ) + self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False) # Evaluate validation loss if X_val is not None and y_val is not None: - val_loss = self.evaluate( - X_val, y_val, metrics={"Mean Squared Error": mean_squared_error} - )["Mean Squared Error"] + val_loss = self.evaluate(X_val, y_val, metrics={"Mean Squared Error": mean_squared_error})[ + "Mean Squared Error" + ] else: - val_loss = self.trainer.validate(self.task_model, self.data_module)[ - 0 - ]["val_loss"] + val_loss = self.trainer.validate(self.task_model, self.data_module)[0]["val_loss"] # Pruning based on validation loss at specific epoch epoch_val_loss = self.task_model.epoch_val_loss_at( # type: ignore @@ -696,21 +661,15 @@ def _objective(hyperparams): except Exception as e: # Penalize the hyperparameter configuration with a large value - print( - f"Error encountered during fit with hyperparameters {hyperparams}: {e}" - ) - return ( - best_val_loss * 100 - ) # Large value to discourage this configuration + print(f"Error encountered during fit with hyperparameters {hyperparams}: {e}") + return best_val_loss * 100 # Large value to discourage this configuration # Perform Bayesian optimization using scikit-optimize result = gp_minimize(_objective, param_space, n_calls=time, random_state=42) # Update the model with the best-found hyperparameters best_hparams = result.x # type: ignore - head_layer_sizes = ( - [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None - ) + head_layer_sizes = [] if "head_layer_sizes" in self.config.__dataclass_fields__ else None layer_sizes = [] if "layer_sizes" in self.config.__dataclass_fields__ else None # Iterate over the best hyperparameters found by optimization diff --git a/mambular/preprocessing/basis_expansion.py b/mambular/preprocessing/basis_expansion.py index 2ac9823..8ee46ed 100644 --- a/mambular/preprocessing/basis_expansion.py +++ b/mambular/preprocessing/basis_expansion.py @@ -43,7 +43,6 @@ def __init__( if spline_implementation not in ["scipy", "sklearn"]: raise ValueError("Invalid spline implementation. Choose 'scipy' or 'sklearn'.") - @staticmethod def knot_identification_using_decision_tree(X, y, task="regression", n_knots=5): # Use DecisionTreeClassifier for classification tasks @@ -75,7 +74,7 @@ def fit(self, X, y=None): raise ValueError("Target variable 'y' must be provided when use_decision_tree=True.") self.knots = [] - self.n_features_in_ = X.shape[1] + self.n_features_in_ = X.shape[1] if self.use_decision_tree and self.spline_implementation == "scipy": self.knots = self.knot_identification_using_decision_tree(X, y, self.task, self.n_knots) @@ -105,8 +104,6 @@ def fit(self, X, y=None): self.transformer.fit(X) self.fitted = True - - elif self.spline_implementation == "sklearn" and not self.use_decision_tree: if self.strategy == "quantile": # print("Using sklearn spline transformer using quantile") @@ -126,7 +123,6 @@ def fit(self, X, y=None): self.fitted = True self.transformer.fit(X) - return self def transform(self, X): diff --git a/mambular/preprocessing/ple_encoding.py b/mambular/preprocessing/ple_encoding.py index 01c217a..3a70f24 100644 --- a/mambular/preprocessing/ple_encoding.py +++ b/mambular/preprocessing/ple_encoding.py @@ -74,7 +74,7 @@ def __init__(self, n_bins=20, tree_params={}, task="regression", conditions=None self.pattern = r"-?\d+\.?\d*[eE]?[+-]?\d*" def fit(self, feature, target): - self.n_features_in_ = 1 + self.n_features_in_ = 1 if self.task == "regression": dt = DecisionTreeRegressor(max_leaf_nodes=self.n_bins) elif self.task == "classification": @@ -85,11 +85,10 @@ def fit(self, feature, target): dt.fit(feature, target) self.conditions = tree_to_code(dt, ["feature"]) - #self.fitted = True + # self.fitted = True return self def transform(self, feature): - if feature.shape == (feature.shape[0], 1): feature = np.squeeze(feature, axis=1) else: @@ -137,8 +136,6 @@ def transform(self, feature): else: return np.array(ple_encoded_feature, dtype=np.float32) - - def get_feature_names_out(self, input_features=None): if input_features is None: raise ValueError("input_features must be specified") diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py index be20529..a691649 100644 --- a/mambular/preprocessing/preprocessor.py +++ b/mambular/preprocessing/preprocessor.py @@ -22,10 +22,10 @@ from .prepro_utils import ( ContinuousOrdinalEncoder, CustomBinner, + LanguageEmbeddingTransformer, NoTransformer, OneHotFromOrdinal, ToFloatTransformer, - LanguageEmbeddingTransformer, ) @@ -111,14 +111,10 @@ def __init__( ): self.n_bins = n_bins self.numerical_preprocessing = ( - numerical_preprocessing.lower() - if numerical_preprocessing is not None - else "none" + numerical_preprocessing.lower() if numerical_preprocessing is not None else "none" ) self.categorical_preprocessing = ( - categorical_preprocessing.lower() - if categorical_preprocessing is not None - else "none" + categorical_preprocessing.lower() if categorical_preprocessing is not None else "none" ) if self.numerical_preprocessing not in [ "ple", @@ -241,19 +237,13 @@ def _detect_column_types(self, X): numerical_features.append(col) else: if isinstance(self.cat_cutoff, float): - cutoff_condition = ( - num_unique_values / total_samples - ) < self.cat_cutoff + cutoff_condition = (num_unique_values / total_samples) < self.cat_cutoff elif isinstance(self.cat_cutoff, int): cutoff_condition = num_unique_values < self.cat_cutoff else: - raise ValueError( - "cat_cutoff should be either a float or an integer." - ) + raise ValueError("cat_cutoff should be either a float or an integer.") - if X[col].dtype.kind not in "iufc" or ( - X[col].dtype.kind == "i" and cutoff_condition - ): + if X[col].dtype.kind not in "iufc" or (X[col].dtype.kind == "i" and cutoff_condition): categorical_features.append(col) else: numerical_features.append(col) @@ -318,11 +308,7 @@ def fit(self, X, y=None): ( "discretizer", KBinsDiscretizer( - n_bins=( - bins - if isinstance(bins, int) - else len(bins) - 1 - ), + n_bins=(bins if isinstance(bins, int) else len(bins) - 1), encode="ordinal", strategy=self.binning_strategy, # type: ignore subsample=200_000 if len(X) > 200_000 else None, @@ -351,17 +337,13 @@ def fit(self, X, y=None): numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.numerical_preprocessing == "minmax": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) elif self.numerical_preprocessing == "quantile": numeric_transformer_steps.append( ( "quantile", - QuantileTransformer( - n_quantiles=self.n_bins, random_state=101 - ), + QuantileTransformer(n_quantiles=self.n_bins, random_state=101), ) ) @@ -369,9 +351,7 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) numeric_transformer_steps.append( ( "polynomial", @@ -386,9 +366,7 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) numeric_transformer_steps.append( ( "splines", @@ -407,9 +385,7 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) numeric_transformer_steps.append( ( "rbf", @@ -426,9 +402,7 @@ def fit(self, X, y=None): if self.scaling_strategy == "standardization": numeric_transformer_steps.append(("scaler", StandardScaler())) elif self.scaling_strategy == "minmax": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) numeric_transformer_steps.append( ( "sigmoid", @@ -442,12 +416,8 @@ def fit(self, X, y=None): ) elif self.numerical_preprocessing == "ple": - numeric_transformer_steps.append( - ("minmax", MinMaxScaler(feature_range=(-1, 1))) - ) - numeric_transformer_steps.append( - ("ple", PLE(n_bins=self.n_bins, task=self.task)) - ) + numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1)))) + numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task))) elif self.numerical_preprocessing == "box-cox": numeric_transformer_steps.append( @@ -513,18 +483,12 @@ def fit(self, X, y=None): ] ) else: - raise ValueError( - f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}" - ) + raise ValueError(f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}") # Append the transformer for the current categorical feature - transformers.append( - (f"cat_{feature}", categorical_transformer, [feature]) - ) + transformers.append((f"cat_{feature}", categorical_transformer, [feature])) - self.column_transformer = ColumnTransformer( - transformers=transformers, remainder="passthrough" - ) + self.column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough") self.column_transformer.fit(X, y) self.fitted = True @@ -550,17 +514,13 @@ def _get_decision_tree_bins(self, X, y, numerical_features): bins = [] for feature in numerical_features: tree_model = ( - DecisionTreeClassifier(max_depth=3) - if y.dtype.kind in "bi" - else DecisionTreeRegressor(max_depth=3) + DecisionTreeClassifier(max_depth=3) if y.dtype.kind in "bi" else DecisionTreeRegressor(max_depth=3) ) tree_model.fit(X[[feature]], y) thresholds = tree_model.tree_.threshold[tree_model.tree_.feature != -2] # type: ignore bin_edges = np.sort(np.unique(thresholds)) - bins.append( - np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])) - ) + bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()]))) return bins def transform(self, X): @@ -716,9 +676,7 @@ def get_feature_info(self, verbose=True): "categories": None, # Numerical features don't have categories } if verbose: - print( - f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}" - ) + print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}") # Categorical features elif "continuous_ordinal" in steps: From 75c2d1b439c357bb58290612110c2a9eeed59d8d Mon Sep 17 00:00:00 2001 From: Manish Kumar Date: Sun, 19 Jan 2025 23:28:02 +0100 Subject: [PATCH 18/18] exclude sentence-transformers --- poetry.lock | 321 +------------------------------------------------ pyproject.toml | 1 - 2 files changed, 1 insertion(+), 321 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5583aa0..eeb1f93 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1290,94 +1290,6 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] -[[package]] -name = "pillow" -version = "11.1.0" -description = "Python Imaging Library (Fork)" -optional = false -python-versions = ">=3.9" -files = [ - {file = "pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8"}, - {file = "pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192"}, - {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2"}, - {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26"}, - {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07"}, - {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482"}, - {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e"}, - {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269"}, - {file = "pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49"}, - {file = "pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a"}, - {file = "pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65"}, - {file = "pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457"}, - {file = "pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35"}, - {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2"}, - {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070"}, - {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6"}, - {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1"}, - {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2"}, - {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96"}, - {file = "pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f"}, - {file = "pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761"}, - {file = "pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71"}, - {file = "pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a"}, - {file = "pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b"}, - {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3"}, - {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a"}, - {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1"}, - {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f"}, - {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91"}, - {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c"}, - {file = "pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6"}, - {file = "pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf"}, - {file = "pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5"}, - {file = "pillow-11.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ae98e14432d458fc3de11a77ccb3ae65ddce70f730e7c76140653048c71bfcbc"}, - {file = "pillow-11.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cc1331b6d5a6e144aeb5e626f4375f5b7ae9934ba620c0ac6b3e43d5e683a0f0"}, - {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:758e9d4ef15d3560214cddbc97b8ef3ef86ce04d62ddac17ad39ba87e89bd3b1"}, - {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b523466b1a31d0dcef7c5be1f20b942919b62fd6e9a9be199d035509cbefc0ec"}, - {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:9044b5e4f7083f209c4e35aa5dd54b1dd5b112b108648f5c902ad586d4f945c5"}, - {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3764d53e09cdedd91bee65c2527815d315c6b90d7b8b79759cc48d7bf5d4f114"}, - {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:31eba6bbdd27dde97b0174ddf0297d7a9c3a507a8a1480e1e60ef914fe23d352"}, - {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b5d658fbd9f0d6eea113aea286b21d3cd4d3fd978157cbf2447a6035916506d3"}, - {file = "pillow-11.1.0-cp313-cp313-win32.whl", hash = "sha256:f86d3a7a9af5d826744fabf4afd15b9dfef44fe69a98541f666f66fbb8d3fef9"}, - {file = "pillow-11.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:593c5fd6be85da83656b93ffcccc2312d2d149d251e98588b14fbc288fd8909c"}, - {file = "pillow-11.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:11633d58b6ee5733bde153a8dafd25e505ea3d32e261accd388827ee987baf65"}, - {file = "pillow-11.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:70ca5ef3b3b1c4a0812b5c63c57c23b63e53bc38e758b37a951e5bc466449861"}, - {file = "pillow-11.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8000376f139d4d38d6851eb149b321a52bb8893a88dae8ee7d95840431977081"}, - {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ee85f0696a17dd28fbcfceb59f9510aa71934b483d1f5601d1030c3c8304f3c"}, - {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:dd0e081319328928531df7a0e63621caf67652c8464303fd102141b785ef9547"}, - {file = "pillow-11.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e63e4e5081de46517099dc30abe418122f54531a6ae2ebc8680bcd7096860eab"}, - {file = "pillow-11.1.0-cp313-cp313t-win32.whl", hash = "sha256:dda60aa465b861324e65a78c9f5cf0f4bc713e4309f83bc387be158b077963d9"}, - {file = "pillow-11.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ad5db5781c774ab9a9b2c4302bbf0c1014960a0a7be63278d13ae6fdf88126fe"}, - {file = "pillow-11.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:67cd427c68926108778a9005f2a04adbd5e67c442ed21d95389fe1d595458756"}, - {file = "pillow-11.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:bf902d7413c82a1bfa08b06a070876132a5ae6b2388e2712aab3a7cbc02205c6"}, - {file = "pillow-11.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c1eec9d950b6fe688edee07138993e54ee4ae634c51443cfb7c1e7613322718e"}, - {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e275ee4cb11c262bd108ab2081f750db2a1c0b8c12c1897f27b160c8bd57bbc"}, - {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4db853948ce4e718f2fc775b75c37ba2efb6aaea41a1a5fc57f0af59eee774b2"}, - {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:ab8a209b8485d3db694fa97a896d96dd6533d63c22829043fd9de627060beade"}, - {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:54251ef02a2309b5eec99d151ebf5c9904b77976c8abdcbce7891ed22df53884"}, - {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5bb94705aea800051a743aa4874bb1397d4695fb0583ba5e425ee0328757f196"}, - {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:89dbdb3e6e9594d512780a5a1c42801879628b38e3efc7038094430844e271d8"}, - {file = "pillow-11.1.0-cp39-cp39-win32.whl", hash = "sha256:e5449ca63da169a2e6068dd0e2fcc8d91f9558aba89ff6d02121ca8ab11e79e5"}, - {file = "pillow-11.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:3362c6ca227e65c54bf71a5f88b3d4565ff1bcbc63ae72c34b07bbb1cc59a43f"}, - {file = "pillow-11.1.0-cp39-cp39-win_arm64.whl", hash = "sha256:b20be51b37a75cc54c2c55def3fa2c65bb94ba859dde241cd0a4fd302de5ae0a"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73"}, - {file = "pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0"}, - {file = "pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20"}, -] - -[package.extras] -docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] -fpx = ["olefile"] -mic = ["olefile"] -tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"] -typing = ["typing-extensions"] -xmp = ["defusedxml"] - [[package]] name = "platformdirs" version = "4.3.6" @@ -1737,109 +1649,6 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] -[[package]] -name = "regex" -version = "2024.11.6" -description = "Alternative regular expression module, to replace re." -optional = false -python-versions = ">=3.8" -files = [ - {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"}, - {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"}, - {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"}, - {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"}, - {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"}, - {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"}, - {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"}, - {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"}, - {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"}, - {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"}, - {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = "sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"}, - {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"}, - {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"}, - {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"}, - {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"}, - {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"}, - {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"}, - {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"}, - {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"}, - {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"}, - {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"}, - {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"}, - {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"}, - {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"}, - {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"}, - {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"}, - {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"}, - {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"}, - {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"}, - {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"}, - {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"}, - {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"}, - {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"}, - {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"}, - {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"}, - {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"}, - {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"}, - {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"}, - {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"}, - {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"}, - {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"}, - {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"}, - {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"}, - {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"}, - {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"}, - {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"}, - {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"}, - {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"}, - {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"}, - {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"}, - {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"}, - {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"}, - {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"}, - {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"}, - {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"}, - {file = "regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"}, - {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"}, - {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"}, - {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"}, - {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"}, - {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"}, - {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"}, - {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"}, - {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"}, - {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"}, - {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"}, - {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"}, - {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"}, - {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"}, - {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"}, - {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"}, - {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"}, - {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"}, - {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"}, - {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"}, - {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"}, - {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"}, - {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"}, - {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"}, -] - [[package]] name = "requests" version = "2.32.3" @@ -2060,33 +1869,6 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.16.5)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.0.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] -[[package]] -name = "sentence-transformers" -version = "3.3.1" -description = "State-of-the-Art Text Embeddings" -optional = false -python-versions = ">=3.9" -files = [ - {file = "sentence_transformers-3.3.1-py3-none-any.whl", hash = "sha256:abffcc79dab37b7d18d21a26d5914223dd42239cfe18cb5e111c66c54b658ae7"}, - {file = "sentence_transformers-3.3.1.tar.gz", hash = "sha256:9635dbfb11c6b01d036b9cfcee29f7716ab64cf2407ad9f403a2e607da2ac48b"}, -] - -[package.dependencies] -huggingface-hub = ">=0.20.0" -Pillow = "*" -scikit-learn = "*" -scipy = "*" -torch = ">=1.11.0" -tqdm = "*" -transformers = ">=4.41.0,<5.0.0" - -[package.extras] -dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"] -onnx = ["optimum[onnxruntime] (>=1.23.1)"] -onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"] -openvino = ["optimum-intel[openvino] (>=1.20.0)"] -train = ["accelerate (>=0.20.3)", "datasets"] - [[package]] name = "setuptools" version = "75.8.0" @@ -2146,38 +1928,6 @@ files = [ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, ] -[[package]] -name = "tokenizers" -version = "0.21.0" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"}, - {file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"}, - {file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"}, - {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"}, - {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"}, - {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"}, - {file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"}, - {file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"}, - {file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"}, - {file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"}, -] - -[package.dependencies] -huggingface-hub = ">=0.16.4,<1.0" - -[package.extras] -dev = ["tokenizers[testing]"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] - [[package]] name = "tomli" version = "2.2.1" @@ -2320,75 +2070,6 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] -[[package]] -name = "transformers" -version = "4.48.0" -description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -optional = false -python-versions = ">=3.9.0" -files = [ - {file = "transformers-4.48.0-py3-none-any.whl", hash = "sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"}, - {file = "transformers-4.48.0.tar.gz", hash = "sha256:03fdfcbfb8b0367fb6c9fbe9d1c9aa54dfd847618be9b52400b2811d22799cb1"}, -] - -[package.dependencies] -filelock = "*" -huggingface-hub = ">=0.24.0,<1.0" -numpy = ">=1.17" -packaging = ">=20.0" -pyyaml = ">=5.1" -regex = "!=2019.12.17" -requests = "*" -safetensors = ">=0.4.1" -tokenizers = ">=0.21,<0.22" -tqdm = ">=4.27" - -[package.extras] -accelerate = ["accelerate (>=0.26.0)"] -agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] -audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -benchmark = ["optimum-benchmark (>=0.3.0)"] -codecarbon = ["codecarbon (>=2.8.1)"] -deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] -flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] -ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] -modelcreation = ["cookiecutter (==1.7.3)"] -natten = ["natten (>=0.14.6,<0.15.0)"] -onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] -onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] -optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] -ray = ["ray[tune] (>=2.7.0)"] -retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -ruff = ["ruff (==0.5.1)"] -sagemaker = ["sagemaker (>=2.31.0)"] -sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic", "starlette", "uvicorn"] -sigopt = ["sigopt"] -sklearn = ["scikit-learn"] -speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] -tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] -tiktoken = ["blobfile", "tiktoken"] -timm = ["timm (<=1.0.11)"] -tokenizers = ["tokenizers (>=0.21,<0.22)"] -torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"] -torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] -video = ["av (==9.2.0)"] -vision = ["Pillow (>=10.0.1,<=15.0)"] - [[package]] name = "triton" version = "3.1.0" @@ -2579,4 +2260,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = ">=3.10, <=3.12" -content-hash = "a2145b5f1d55eea1ccc3ef498aa90aaa96a62ab56928a3adf33a51b2700361be" +content-hash = "c7606af7fb47a2fb5e856b23ef3e06a1740544bda46470dafeb7c7a3ca794d5e" diff --git a/pyproject.toml b/pyproject.toml index e2ce51b..08a2476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ scikit-optimize = "^0.10.2" einops = "^0.8.0" accelerate = "^1.2.1" scipy = "^1.15.0" -sentence-transformers = "^3.3.1" [tool.poetry.group.dev.dependencies] pytest = "^8.1"