add autoselector + benchmarks #616

Open · wants to merge 5 commits into base: development
Binary file added .DS_Store
Binary file not shown.
Binary file added examples/.DS_Store
Binary file not shown.
@@ -0,0 +1,14 @@
Pipeline Design Dimensions,MO-GAAL,AutoEncoder,SO-GAAL,VAE,AnoGAN,Deep SVDD,ALAD
Data Augmentation,"SMOTE, GAN-based Oversampling",N/A,"Oversampling, GAN-based Augmentation",N/A,N/A,N/A,N/A
Data Preprocessing,"Normalization, Standardization","StandardScaler, Normalization","MinMax Scaling, Standardization","StandardScaler, MinMax Scaling","StandardScaler, MinMax Scaling","StandardScaler, MinMax Scaling","StandardScaler, MinMax Scaling"
Network Architecture,"Discriminator, Generator (GAN), MLP, AutoEncoder","AutoEncoder, MLP","Discriminator, Generator (GAN)","VAE, beta-VAE, AutoEncoder","Discriminator, Generator (GAN)","AutoEncoder, One-Class Classifier","Discriminator, Generator (GAN)"
Hidden Layers,"[[32, 16], [64, 32, 16], [128, 64, 32, 16]]","[[32, 16], [64, 32, 16], [128, 64, 32, 16]]","[[64, 32], [128, 64, 32], [256, 128, 64, 32]]","[[64, 32], [128, 64, 32], [256, 128, 64]]","[[64, 32], [128, 64, 32], [256, 128, 64, 32]]","[[64, 32], [128, 64, 32], [256, 128, 64, 32]]","[[64, 32], [128, 64, 32], [256, 128, 64]]"
Activation,"ReLU, LeakyReLU, Tanh","ReLU, Tanh, LeakyReLU","ReLU, LeakyReLU, Tanh","ReLU, Tanh, LeakyReLU","ReLU, LeakyReLU, Tanh","ReLU, Tanh, LeakyReLU","ReLU, LeakyReLU, Tanh"
Dropout,"0.0, 0.25, 0.5","0.0, 0.2, 0.5","0.0, 0.25, 0.5","0.0, 0.3, 0.5","0.0, 0.25, 0.5","0.0, 0.2, 0.5","0.0, 0.25, 0.5"
Initialization,"Xavier, He, Random Normal","Xavier, He, Random Normal","Xavier, He, Random Normal","Xavier, He, Random Normal","Xavier, He, Random Normal","Xavier, He, Random Normal","Xavier, He, Random Normal"
Loss Function,"BCE, WGAN Loss, Hinge Loss","Mean Squared Error, Binary Crossentropy","BCE, WGAN Loss","Reconstruction Loss, KL Divergence","Binary Crossentropy, Anomaly Score Loss","SVDD Loss, Reconstruction Loss (AE)","Binary Crossentropy, Adversarial Loss"
Optimizer,"SGD, Adam, RMSprop","SGD, Adam, RMSprop","Adam, RMSprop, SGD","Adam, RMSprop, SGD","Adam, RMSprop, SGD","Adam, RMSprop, SGD","Adam, RMSprop, SGD"
Epochs,"50, 100, 200","50, 100, 200","50, 100, 150","50, 100, 200","50, 100, 150","50, 100, 200","50, 100, 150"
Batch Size,"32, 64, 128","32, 64, 128","32, 64, 128","32, 64, 128","32, 64, 128","32, 64, 128","32, 64, 128"
Learning Rate,"1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5","1e-3, 1e-4, 1e-5"
Weight Decay,"1e-2, 1e-4","1e-2, 1e-4","1e-2, 1e-4","1e-2, 1e-4","1e-2, 1e-4","1e-2, 1e-4","1e-2, 1e-4"
7,306 changes: 7,306 additions & 0 deletions examples/auto_model_selection_example/auto_selection.ipynb


260 changes: 260 additions & 0 deletions examples/auto_model_selection_example/model_info_summarizer.ipynb


18 changes: 18 additions & 0 deletions examples/auto_model_selection_example/prn_df.csv
@@ -0,0 +1,18 @@
,Data,#Samples,# Dimensions,Outlier Perc,MO-GAAL,SO-GAAL,AutoEncoder,VAE,AnoGAN,DeepSVDD,ALAD,AE1SVM,DevNet,LUNAR
0,arrhythmia,452,274,14.6018,0.3214,0.2857,0.3929,0.4286,0.4643,0.3929,0.0357,0.3929,0.0714,0.4643
0,cardio,1831,21,9.6122,0.2,0.4143,0.3,0.6571,0.3429,0.5,0.2429,0.5571,0.0,0.2143
0,glass,214,9,4.2056,0.0,0.2,0.2,0.2,0.2,0.2,0.0,0.2,0.0,0.2
0,ionosphere,351,33,35.8974,0.5217,0.413,0.587,0.4783,0.4348,0.4783,0.3261,0.4565,0.5,0.8043
0,letter,1600,32,6.25,0.0488,0.05,0.2439,0.0732,0.0488,0.0488,0.0732,0.122,0.1951,0.4146
0,lympho,148,18,4.0541,0.3333,0.3333,0.6667,0.6667,0.6667,0.3333,0.3333,0.3333,0.0,0.6667
0,mnist,7603,100,9.2069,0.1963,0.2852,0.3741,0.4519,0.2519,0.2704,0.0963,0.3852,0.0704,0.337
0,musk,3062,166,3.1679,0.0976,0.0,0.3659,1.0,0.0976,0.4146,0.0244,1.0,0.9268,0.1951
0,optdigits,5216,64,2.8758,0.1385,0.0,0.0,0.0,0.0308,0.0,0.0,0.0,0.0,0.0462
0,pendigits,6870,16,2.2707,0.0,0.0455,0.0484,0.3065,0.0161,0.0484,0.0484,0.2419,0.0,0.129
0,pima,768,8,34.8958,0.2212,0.1681,0.5752,0.4867,0.5487,0.5133,0.3805,0.4779,0.2743,0.5575
0,satellite,6435,36,31.6395,0.5111,0.4059,0.4951,0.5603,0.5542,0.4163,0.3214,0.5591,0.383,0.431
0,satimage-2,5803,36,1.2235,0.0,0.0,0.3226,0.7097,0.6129,0.5161,0.0645,0.6129,0.0,0.2903
0,shuttle,49097,9,7.1511,0.6092,0.0,0.9068,0.957,0.9591,0.9527,0.2399,0.9245,0.0064,0.1934
0,vertebral,240,6,12.5,0.1429,0.1429,0.0,0.1429,0.2143,0.0,0.2143,0.1429,0.2143,0.0
0,vowels,1456,12,3.4341,0.0,0.0,0.4545,0.2727,0.0,0.0909,0.0,0.3636,0.0,0.5455
0,wbc,378,30,5.5556,0.0,0.0,0.5,0.6,0.6,0.5,0.0,0.5,0.0,0.4
126 changes: 126 additions & 0 deletions examples/auto_model_selection_example/readme.md
@@ -0,0 +1,126 @@
### Framework Explanation: A Neurosymbolic Approach to Dataset Analysis and Automated Model Selection

---

### **1. Loading Model Analyses: Encoding Knowledge as Symbolic Tags**
The framework begins by processing pre-analyzed model metadata stored in JSON files. Each file contains explicit, symbolic descriptions of a model's **strengths** and **weaknesses**, designed to support reasoning about the model's applicability to specific datasets.

At the core of our framework are symbolic representations that characterize each model's strengths and weaknesses. These symbols are derived systematically from each model's paper and source code, so the descriptions capture both the theoretical foundations and the practical implementations of the models.

Specifically:
- **Paper Analysis**: We review the primary research paper for each model to identify its targeted application domains, key advantages, and known limitations. For instance, AnoGAN's foundational paper highlights its focus on high-dimensional medical images.
- **Code Inspection**: By examining the source code, we extract implementation-specific details such as computational requirements (e.g., GPU dependence), scalability considerations, and specific preprocessing requirements. This complements the theoretical understanding provided in the papers.

These extracted symbols are then structured into:
- **Strengths**: Key attributes where the model excels, represented as labels such as "images", "medical", or "high dimensionality". Each label includes a detailed explanation derived from the literature and implementation insights.
- **Weaknesses**: Known limitations or scenarios where the model is less effective, such as "small data size" or "real-time data".


#### **Key Functionality:**
- **Input:** JSON files for each model with tags such as:
- `strengths`: e.g., "images", "medical", "high dimensionality".
- `weaknesses`: e.g., "small data size", "text data".
- **Process:** Extract and store symbolic information in a structured dictionary, where each model is mapped to its strengths and weaknesses.
- **Output:** A structured knowledge base enabling symbolic reasoning in later stages.

#### **Symbolic Value:**
By explicitly encoding domain-specific properties and limitations, this step transforms the selection process into a logical reasoning task, allowing systematic alignment with dataset characteristics.
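The loading step can be sketched as follows. This is a minimal sketch: the one-file-per-model layout and the top-level `strengths`/`weaknesses` keys are assumptions based on the description above, not the framework's actual file schema.

```python
import json
from pathlib import Path

def load_model_knowledge(metadata_dir):
    """Build a knowledge base mapping each model name to its symbolic tags.

    Assumes one JSON file per model (e.g. AnoGAN.json) whose top-level
    "strengths" and "weaknesses" objects map tag -> explanation.
    """
    knowledge = {}
    for path in sorted(Path(metadata_dir).glob("*.json")):
        with open(path) as f:
            info = json.load(f)
        knowledge[path.stem] = {
            "strengths": info.get("strengths", {}),
            "weaknesses": info.get("weaknesses", {}),
        }
    return knowledge
```

The resulting dictionary is the structured knowledge base consumed by the matching stage, keyed by model name.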

---

### **2. Dataset Profiling: Statistical Summarization and Tagging**
The framework analyzes the input dataset to produce a comprehensive statistical profile, summarizing its key characteristics. These include both high-level descriptors (e.g., data types, dimensionality) and deeper statistical properties (e.g., skewness, kurtosis).

#### **Key Functionality:**
- **Input:** Raw dataset (`pandas.DataFrame`) and optional user notes.
- **Process:**
- Compute dataset-level attributes such as shape, data type distribution, missing value ratio, and numerical feature statistics.
- Quantify statistical metrics for numerical columns, such as skewness and kurtosis, to capture data complexity.
- Generate symbolic tags (e.g., "noisy data", "high dimensionality") based on the profile.
- **Output:** A structured dataset description and standardized symbolic tags.

#### **Neurosymbolic Integration:**
- **Symbolic:** Converts raw statistical features into tags, enabling alignment with model descriptions.
- **Neural:** Uses GPT to refine and adapt the tags, ensuring compatibility with the downstream symbolic reasoning framework.
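The statistical half of this step can be sketched as below. The tag thresholds (feature count, sample count, skewness cut-off) are illustrative placeholders for this sketch, not the framework's actual cut-offs.

```python
import numpy as np
import pandas as pd

def profile_dataset(df, notes=""):
    """Summarize the dataset statistics described above and derive symbolic tags."""
    numeric = df.select_dtypes(include=np.number)
    profile = {
        "n_samples": int(df.shape[0]),
        "n_features": int(df.shape[1]),
        "missing_ratio": float(df.isna().mean().mean()),
        "mean_skewness": float(numeric.skew().abs().mean()),
        "mean_kurtosis": float(numeric.kurtosis().mean()),
        "notes": notes,
    }
    tags = []
    if profile["n_features"] > 100:
        tags.append("high dimensionality")
    if profile["n_samples"] < 500:
        tags.append("small data size")
    if profile["mean_skewness"] > 2:  # heavy-tailed features as a rough noise proxy
        tags.append("noisy data")
    return profile, tags
```

The profile dictionary is what gets handed to GPT in the next step; the rule-based tags serve as a symbolic baseline that the model can refine.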

---

### **3. GPT-Driven Tagging: Neural Refinement of Dataset Properties**
Using the dataset's statistical summary, the GPT model generates a refined, standardized set of tags that describe the dataset in terms relevant to model selection. These tags represent the dataset's **semantic properties**, such as size, domain, and computational requirements.

#### **Key Functionality:**
- **Input:** Statistical summary of the dataset, including computed metrics and descriptive notes.
- **Process:** GPT generates tags in JSON format using predefined categories:
- Data size: e.g., "small", "medium", "large".
- Data type: e.g., "images", "tabular data".
- Domain: e.g., "medical", "finance".
- Characteristics: e.g., "noisy data", "imbalanced data".
- Computational constraints: e.g., "GPU", "high memory".
- **Output:** JSON-formatted tags, ready for comparison with model strengths and weaknesses.

#### **Neural Value:**
GPT’s ability to generalize across diverse datasets ensures the generated tags align semantically with model descriptions, even for datasets with novel or ambiguous characteristics.
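A minimal sketch of this stage is below. The `llm` callable is a stand-in introduced for this sketch (any wrapper that takes a prompt string and returns the reply text, e.g. around a chat-completion client); the tag vocabulary in the prompt is abbreviated from the categories listed above.

```python
import json

def generate_dataset_tags(profile, llm):
    """Ask a language model to map a statistical profile to standardized tags."""
    prompt = (
        "Given this dataset summary, return a JSON object with a single "
        '"tags" list drawn from categories such as data size (small/medium/'
        "large), data type (images/tabular data), domain (medical/finance), "
        "characteristics (noisy data/imbalanced data), and computational "
        "constraints (GPU/high memory).\n"
        f"Summary: {json.dumps(profile)}"
    )
    reply = llm(prompt).strip()
    if reply.startswith("```"):  # tolerate fenced JSON replies
        reply = reply.strip("`").removeprefix("json").strip()
    return json.loads(reply)["tags"]
```

Validating the reply as JSON (and tolerating code fences) keeps the neural output consumable by the symbolic matcher downstream.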

---

### **4. Automated Model Selection: Symbolic Reasoning Enhanced by Neural Insight**
This step compares dataset tags with model metadata to determine the most suitable model for the given dataset. The decision-making process combines:
- Symbolic reasoning for structured tag alignment.
- Neural capabilities of GPT for complex, context-aware recommendations.

#### **Key Functionality:**
- **Input:** Dataset tags, model strengths and weaknesses, and a list of available models.
- **Process:**
- Symbolic matching of dataset tags to model strengths.
- Neural reasoning via GPT to evaluate trade-offs between competing models.
- Generate a JSON output with the recommended model and an explanation of the decision.
- **Output:** Selected model and rationale.

#### **Example:**
Given a dataset described by:
```json
{
"tags": ["images", "medical", "high dimensionality", "noisy data", "GPU"]
}
```
And the model `AnoGAN` with strengths like "medical", "images", and weaknesses like "small data size", GPT selects `AnoGAN` due to its strong alignment with the dataset properties and mitigable weaknesses.
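The symbolic half of this matching reduces to set overlap before GPT arbitrates trade-offs. The ±1 weighting below is an illustrative choice for this sketch, not the framework's scoring rule.

```python
def rank_models(dataset_tags, knowledge):
    """Rank models by symbolic tag overlap:
    +1 per dataset tag matching a strength, -1 per tag matching a weakness."""
    tags = set(dataset_tags)
    scores = {
        model: len(tags & set(info["strengths"])) - len(tags & set(info["weaknesses"]))
        for model, info in knowledge.items()
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```

On the example above, `AnoGAN` matching "images", "medical", and "high dimensionality" outscores a model whose strengths miss the dataset tags, and the ranking plus per-model scores give GPT a structured starting point for its explanation.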

---

### **5. Model Deployment: Dynamically Instantiating the Classifier**
Once a model is selected, the framework dynamically initializes it with appropriate configurations, ready for training or inference.

#### **Key Functionality:**
- **Input:** Selected model name and its hyperparameter settings.
- **Process:** Import the relevant model class from the library, set its parameters, and return an initialized instance.
- **Output:** A fully instantiated classifier object.

#### **Example Deployment:**
For `AnoGAN`, the framework initializes the model with GPU acceleration, batch size, and epoch settings tailored to the dataset. Conversely, for text-based datasets, it avoids image-specific models like `AnoGAN`.
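Dynamic instantiation can be sketched with a registry and `importlib`. The registry below is hypothetical — the module paths mirror pyod's layout but are assumptions for this sketch, not the framework's actual mapping.

```python
import importlib

# Hypothetical registry mapping model names to (module path, class name);
# paths follow pyod's layout but are assumptions, not the framework's mapping.
MODEL_REGISTRY = {
    "AnoGAN": ("pyod.models.anogan", "AnoGAN"),
    "AutoEncoder": ("pyod.models.auto_encoder", "AutoEncoder"),
    "DeepSVDD": ("pyod.models.deep_svdd", "DeepSVDD"),
}

def build_model(name, registry=MODEL_REGISTRY, **params):
    """Dynamically import the selected model class and instantiate it
    with the chosen hyperparameters."""
    module_path, class_name = registry[name]
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**params)
```

Keeping the registry as plain data means adding a new model only requires a metadata file and one registry entry, with no changes to the selection logic.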

---

### **Advantages of the Framework**
#### **1. Symbolic Reasoning for Interpretability**
The explicit use of symbolic tags for models and datasets enhances interpretability, providing clear explanations for why a model was selected.

#### **2. Neural Flexibility for Complex Reasoning**
GPT’s neural capabilities enable nuanced trade-offs in ambiguous scenarios, such as datasets that partially align with multiple models.

#### **3. Generality Across Domains**
The modular design accommodates diverse datasets, from images to tabular data, and seamlessly integrates new models and tags.

#### **4. Automation and Scalability**
By automating both dataset profiling and model selection, the framework reduces the need for manual intervention, making it scalable for real-world applications.

---

### **End-to-End Example: A Neurosymbolic Workflow**
1. **Input:**
- Dataset Tags: `["images", "medical", "high dimensionality", "noisy data", "GPU"]`
- Models: `AnoGAN`, `AutoEncoder`, `DeepSVDD`.
2. **Output:**
- Selected Model: `AnoGAN`
- Rationale: `"AnoGAN's strengths align with the dataset properties, particularly its focus on medical images and handling of high-dimensional, noisy data."`

This neurosymbolic approach ensures robust, explainable, and efficient model selection tailored to the needs of complex datasets.
18 changes: 18 additions & 0 deletions examples/auto_model_selection_example/result_with_baseline.csv
@@ -0,0 +1,18 @@
,Data,AutoEncoder,LUNAR,without note,with note,Average_Performance
0,arrhythmia,0.8116,0.8284,0.8176,0.8176,0.66415
1,cardio,0.7847,0.5704,0.7255,0.7847,0.69021
2,glass,0.5901,0.7926,0.7926,0.7926,0.54024
3,ionosphere,0.7851,0.9156,0.7675,0.9156,0.7108
4,letter,0.8087,0.9057,0.5884,0.5884,0.57632
5,lympho,0.9825,0.9357,0.9357,0.9649,0.72691
6,mnist,0.8567,0.7411,0.8567,0.9002,0.6943
7,musk,0.8853,0.7666,1.0,0.8853,0.77917
8,optdigits,0.5124,0.4836,0.5124,0.5074,0.56013
9,pendigits,0.6687,0.6973,0.9273,0.7824,0.67171
10,pima,0.7189,0.7177,0.7177,0.6013,0.54117
11,satellite,0.6431,0.6179,0.7419,0.6431,0.6349
12,satimage-2,0.8684,0.8161,0.9948,0.8684,0.88427
13,shuttle,0.9939,0.641,0.9947,0.9939,0.79559
14,vertebral,0.324,0.2552,0.4172,0.4172,0.46601
15,vowels,0.9305,0.946,0.7489,0.9305,0.61993
16,wbc,0.9556,0.9042,0.9042,0.9218,0.62985
18 changes: 18 additions & 0 deletions examples/auto_model_selection_example/roc_df.csv
@@ -0,0 +1,18 @@
,Unnamed: 0,Data,#Samples,# Dimensions,Outlier Perc,MO-GAAL,SO-GAAL,AutoEncoder,VAE,AnoGAN,DeepSVDD,ALAD,AE1SVM,DevNet,LUNAR,Average_Performance,without_note_modes,with_note_modes,without note,with note
0,0,arrhythmia,452,274,14.6018,0.616,0.5661,0.8116,0.8176,0.7946,0.7554,0.4232,0.8176,0.211,0.8284,0.66415,VAE,AE1SVM,0.8176,0.8176
1,0,cardio,1831,21,9.6122,0.5603,0.7255,0.7847,0.9615,0.8004,0.934,0.6025,0.9125,0.0503,0.5704,0.69021,SO-GAAL,AutoEncoder,0.7255,0.7847
2,0,glass,214,9,4.2056,0.4247,0.3975,0.5901,0.6099,0.6765,0.4,0.2667,0.6864,0.558,0.7926,0.54024,LUNAR,LUNAR,0.7926,0.7926
3,0,ionosphere,351,33,35.8974,0.6751,0.5908,0.7851,0.7675,0.66,0.7535,0.5046,0.773,0.6828,0.9156,0.7108,VAE,LUNAR,0.7675,0.9156
4,0,letter,1600,32,6.25,0.3489,0.3074,0.8087,0.5884,0.5223,0.5133,0.4811,0.5883,0.6991,0.9057,0.57632,VAE,VAE,0.5884,0.5884
5,0,lympho,148,18,4.0541,0.5263,0.3918,0.9825,0.9825,0.9825,0.8421,0.6374,0.9649,0.0234,0.9357,0.72691,LUNAR,AE1SVM,0.9357,0.9649
6,0,mnist,7603,100,9.2069,0.6122,0.6926,0.8567,0.9002,0.6688,0.7207,0.4786,0.8721,0.4,0.7411,0.6943,AutoEncoder,VAE,0.8567,0.9002
7,0,musk,3062,166,3.1679,0.5686,0.4388,0.8853,1.0,0.806,0.9505,0.3772,1.0,0.9987,0.7666,0.77917,AE1SVM,AutoEncoder,1.0,0.8853
8,0,optdigits,5216,64,2.8758,0.6552,0.4641,0.5124,0.5074,0.8159,0.5199,0.494,0.4455,0.7033,0.4836,0.56013,AutoEncoder,VAE,0.5124,0.5074
9,0,pendigits,6870,16,2.2707,0.6974,0.5114,0.6687,0.9273,0.8492,0.7824,0.569,0.9097,0.1047,0.6973,0.67171,VAE,DeepSVDD,0.9273,
10,0,pima,768,8,34.8958,0.2915,0.2593,0.7189,0.6112,0.6403,0.6713,0.5437,0.6013,0.3565,0.7177,0.54117,LUNAR,AE1SVM,0.7177,0.6013
11,0,satellite,6435,36,31.6395,0.6742,0.5679,0.6431,0.7419,0.7241,0.5839,0.5014,0.7522,0.5424,0.6179,0.6349,VAE,AutoEncoder,0.7419,0.6431
12,0,satimage-2,5803,36,1.2235,0.9693,0.8797,0.8684,0.9948,0.9667,0.9667,0.5848,0.9922,0.804,0.8161,0.88427,VAE,AutoEncoder,0.9948,0.8684
13,0,shuttle,49097,9,7.1511,0.9085,0.7012,0.9939,0.9947,0.9866,0.9935,0.649,0.99,0.0975,0.641,0.79559,VAE,AutoEncoder,0.9947,0.9939
14,0,vertebral,240,6,12.5,0.5409,0.6524,0.324,0.4172,0.6838,0.2587,0.4869,0.4434,0.5976,0.2552,0.46601,VAE,VAE,0.4172,0.4172
15,0,vowels,1456,12,3.4341,0.1879,0.3373,0.9305,0.7489,0.5153,0.6694,0.3423,0.7626,0.7591,0.946,0.61993,VAE,AutoEncoder,0.7489,0.9305
16,0,wbc,378,30,5.5556,0.062,0.131,0.9556,0.9218,0.9437,0.9197,0.4408,0.9296,0.0901,0.9042,0.62985,LUNAR,VAE,0.9042,0.9218
18 changes: 18 additions & 0 deletions examples/auto_model_selection_example/time_df.csv
@@ -0,0 +1,18 @@
,Data,#Samples,# Dimensions,Outlier Perc,MO-GAAL,SO-GAAL,AutoEncoder,VAE,AnoGAN,DeepSVDD,ALAD,AE1SVM,DevNet,LUNAR
0,arrhythmia,452,274,14.6018,2.2201,0.4481,2.3577,0.8746,25.0503,0.4209,2.6153,1.4392,0.3011,1.0523
0,cardio,1831,21,9.6122,1.935,0.4122,0.6947,2.932,95.7761,1.4623,2.7029,5.3502,0.4773,0.9887
0,glass,214,9,4.2056,0.6103,0.0928,0.0874,0.3567,11.0639,0.1817,2.5448,0.6271,0.0824,0.9243
0,ionosphere,351,33,35.8974,0.6259,0.1123,0.1285,0.5312,18.8991,0.3043,2.5635,1.0727,0.1084,1.2895
0,letter,1600,32,6.25,1.4954,0.3275,0.6169,2.614,82.8181,1.2625,2.7119,4.6007,0.4063,0.9716
0,lympho,148,18,4.0541,0.5972,0.0854,0.0476,0.1826,8.0602,0.1397,2.5423,0.4654,0.0658,1.3465
0,mnist,7603,100,9.2069,7.1153,2.5733,2.926,12.5522,398.1935,5.9011,3.3021,21.9616,2.5411,2.3247
0,musk,3062,166,3.1679,2.7724,1.497,1.1978,5.0226,162.3098,2.4641,2.8737,9.0085,1.3459,1.257
0,optdigits,5216,64,2.8758,4.8944,1.4061,2.0056,8.4954,272.4242,4.0775,3.0446,15.0324,1.4741,1.5397
0,pendigits,6870,16,2.2707,6.2045,1.5857,2.6522,11.2758,351.1087,5.0944,3.1435,19.4177,1.4525,1.6487
0,pima,768,8,34.8958,0.6682,0.1511,0.2856,1.2059,39.4155,0.6038,2.5746,2.265,0.1911,0.8936
0,satellite,6435,36,31.6395,5.5677,1.5037,2.4127,10.2862,326.4625,4.8506,3.109,18.271,1.5682,1.8521
0,satimage-2,5803,36,1.2235,4.8854,1.2301,2.177,9.2547,295.0104,4.2977,3.0388,16.3908,1.4015,1.6889
0,shuttle,49097,9,7.1511,54.5689,12.518,18.359,78.2603,2472.4886,36.2362,7.0884,138.645,12.3863,9.7518
0,vertebral,240,6,12.5,0.6014,0.0915,0.0858,0.3476,12.7787,0.2069,2.5141,0.7572,0.0818,0.8338
0,vowels,1456,12,3.4341,1.4241,0.2844,0.545,2.3138,74.3473,1.1114,2.6293,4.2116,0.3274,0.9719
0,wbc,378,30,5.5556,0.6167,0.1117,0.1449,0.6031,20.4393,0.3196,2.5309,1.2024,0.1141,1.2352
@@ -0,0 +1,5 @@
Dataset,Domain,Data Type,Number of Instances,Number of Attributes,Missing Values,Area,Mean,Std Dev,Min,Max
arrhythmia,Medical,Multivariate,452,279,Yes,Cardiology,10.125108197144888,6.546008639872744,-242.4,780.0
glass,Forensic Science,Multivariate,214,9,No,Material Identification,11.265851609553478,0.6879278145483789,0.0,75.41
ionosphere,Astronomy,Multivariate,351,34,No,Space,0.2552029672796339,0.5251381269031516,-1.0,1.0
lympho,Medical,Multivariate,148,18,No,Oncology,2.0746996996997,0.6997776820615417,1.0,8.0