-
Notifications
You must be signed in to change notification settings - Fork 26
/
training_information.ts
253 lines (219 loc) · 7.04 KB
/
training_information.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import { PreTrainedTokenizer } from "@xenova/transformers";
import { DataType } from "../index.js";
/** Differential-privacy settings used during collaborative training. */
interface Privacy {
  /** Maximum weights difference between each round. */
  clippingRadius?: number;
  /** Variance of the Gaussian noise added to the shared weights. */
  noiseScale?: number;
}
/**
 * Training hyper-parameters and collaboration settings common to every data
 * type, intersected with the fields specific to the data type `D`.
 */
export type TrainingInformation<D extends DataType> = {
  /** Number of epochs to run training for. */
  epochs: number;
  /**
   * Number of epochs between each weight-sharing round; e.g. if 3, weights
   * are shared every 3 epochs (in the distributed setting).
   */
  roundDuration: number;
  /** Fraction of the data kept for validation (note: only works for image data). */
  validationSplit: number;
  /** Batch size of the training data. */
  batchSize: number;
  /** Training scheme: decentralized, federated, or local. */
  scheme: "decentralized" | "federated" | "local";
  /** Differential Privacy settings: reduce training accuracy, improve privacy. */
  privacy?: Privacy;
  /**
   * Secure aggregation: maximum absolute value of a number in a randomly
   * generated share. Default is 100 and it must be a positive number; see
   * docs/PRIVACY.md for the significance of choosing this value. Only
   * relevant if secure aggregation is used (federated or decentralized).
   */
  maxShareValue?: number;
  /**
   * Minimum number of participants required to train collaboratively.
   * In decentralized learning the default is 3; in federated learning it is 2.
   */
  minNbOfParticipants: number;
  /**
   * Aggregator used by the server for federated learning, or by the peers
   * for decentralized learning. Default is 'mean'.
   */
  aggregationStrategy?: "mean" | "secure";
  /** Tensor framework used by the model. */
  tensorBackend: "tfjs" | "gpt";
} & DataTypeToTrainingInformation[D];
/**
 * Extra training fields required by each data type. Every entry repeats its
 * own `dataType` literal, which is what lets `dataType` tell the variants of
 * `TrainingInformation<DataType>` apart.
 */
interface DataTypeToTrainingInformation {
  image: {
    dataType: "image";
    /**
     * List of class labels; e.g. with two classes of images, one of dogs and
     * one of cats, this would be `['dogs', 'cats']`.
     */
    LABEL_LIST: string[];
    /** Image height (or RESIZED_IMAGE_H if ImagePreprocessing.Resize is in preprocessingFunctions). */
    IMAGE_H: number;
    /** Image width (or RESIZED_IMAGE_W if ImagePreprocessing.Resize is in preprocessingFunctions). */
    IMAGE_W: number;
  };
  tabular: {
    dataType: "tabular";
    /** Columns chosen as input data for the model. */
    inputColumns: string[];
    /** Column to be predicted by the model. */
    outputColumn: string;
  };
  text: {
    dataType: "text";
    /**
     * Tokenizer: should be initialized with the name of a Transformers.js
     * pre-trained tokenizer, e.g. 'Xenova/gpt2'. When the tokenizer is first
     * called, the actual object is initialized and loaded into this field for
     * subsequent tokenizations.
     */
    tokenizer: string | PreTrainedTokenizer;
    /**
     * Maximum length of an input string fed to a GPT model; used during
     * preprocessing to truncate strings. The default value is
     * tokenizer.model_max_length.
     */
    contextLength: number;
  };
}
/**
 * Runtime type guard for `Privacy`.
 *
 * Accepts any non-null object whose `clippingRadius` and `noiseScale`
 * properties are each either missing/undefined or a number. Any other
 * properties are ignored.
 *
 * @param raw - value to validate
 * @returns `true` when `raw` can be used as a `Privacy`
 */
function isPrivacy(raw: unknown): raw is Privacy {
  if (typeof raw !== "object" || raw === null) return false;

  const fields: Partial<Record<keyof Privacy, unknown>> = raw;
  const { clippingRadius, noiseScale } = fields;

  // Both settings are optional, but when present they must be numbers.
  if (clippingRadius !== undefined && typeof clippingRadius !== "number") {
    return false;
  }
  if (noiseScale !== undefined && typeof noiseScale !== "number") {
    return false;
  }

  // Compile-time completeness check: `satisfies` requires every key of
  // Privacy to be listed here, so adding a field to Privacy forces this
  // guard to be updated.
  const _: Privacy = {
    clippingRadius,
    noiseScale,
  } satisfies Record<keyof Privacy, unknown>;
  return true;
}
/**
 * Runtime type guard for `TrainingInformation<DataType>`.
 *
 * Validates an untrusted value in two steps:
 * 1. the fields shared by every data type must be present with the expected
 *    primitive type or literal value (optional ones may be absent);
 * 2. the fields specific to `dataType` ("image", "tabular" or "text") are
 *    checked.
 * Unknown extra properties are ignored.
 *
 * Numeric fields are only checked with `typeof`, except `maxShareValue`,
 * which, when provided, must also be a positive number as its field
 * documentation requires.
 *
 * @param raw - value to validate
 * @returns `true` when `raw` can be used as a `TrainingInformation<DataType>`
 */
export function isTrainingInformation(
  raw: unknown,
): raw is TrainingInformation<DataType> {
  if (typeof raw !== "object" || raw === null) {
    return false;
  }

  const {
    aggregationStrategy,
    batchSize,
    dataType,
    privacy,
    epochs,
    maxShareValue,
    minNbOfParticipants,
    roundDuration,
    scheme,
    validationSplit,
    tensorBackend,
  }: Partial<Record<keyof TrainingInformation<DataType>, unknown>> = raw;

  if (
    typeof epochs !== "number" ||
    typeof batchSize !== "number" ||
    typeof roundDuration !== "number" ||
    typeof validationSplit !== "number" ||
    typeof minNbOfParticipants !== "number" ||
    (privacy !== undefined && !isPrivacy(privacy)) ||
    (maxShareValue !== undefined && typeof maxShareValue !== "number")
  ) {
    return false;
  }

  // maxShareValue is documented as "must be a positive number". `!(x > 0)`
  // rejects 0, negatives and NaN (which `typeof` alone lets through).
  if (maxShareValue !== undefined && !(maxShareValue > 0)) {
    return false;
  }

  switch (aggregationStrategy) {
    case undefined:
    case "mean":
    case "secure":
      break;
    default:
      return false;
  }

  switch (tensorBackend) {
    case "tfjs":
    case "gpt":
      break;
    default:
      return false;
  }

  switch (scheme) {
    case "decentralized":
    case "federated":
    case "local":
      break;
    default:
      return false;
  }

  // Narrowed shared fields, reused by every data-type branch below.
  const repack = {
    aggregationStrategy,
    batchSize,
    epochs,
    maxShareValue,
    minNbOfParticipants,
    privacy,
    roundDuration,
    scheme,
    tensorBackend,
    validationSplit,
  };

  switch (dataType) {
    case "image": {
      type ImageOnly = Omit<
        TrainingInformation<"image">,
        keyof TrainingInformation<DataType>
      >;
      const { LABEL_LIST, IMAGE_W, IMAGE_H }: Partial<ImageOnly> = raw;
      if (
        !(
          Array.isArray(LABEL_LIST) &&
          LABEL_LIST.every((e) => typeof e === "string")
        ) ||
        typeof IMAGE_H !== "number" ||
        typeof IMAGE_W !== "number"
      )
        return false;

      // Compile-time completeness check: `satisfies` requires every key of
      // the target type to be listed, so adding a field to the type forces
      // this guard to be updated.
      const _: TrainingInformation<"image"> = {
        ...repack,
        dataType,
        LABEL_LIST,
        IMAGE_W,
        IMAGE_H,
      } satisfies Record<keyof TrainingInformation<"image">, unknown>;
      return true;
    }
    case "tabular": {
      type TabularOnly = Omit<
        TrainingInformation<"tabular">,
        keyof TrainingInformation<DataType>
      >;
      const { inputColumns, outputColumn }: Partial<TabularOnly> = raw;
      if (
        !(
          Array.isArray(inputColumns) &&
          inputColumns.every((e) => typeof e === "string")
        ) ||
        typeof outputColumn !== "string"
      )
        return false;

      const _: TrainingInformation<"tabular"> = {
        ...repack,
        dataType,
        inputColumns,
        outputColumn,
      } satisfies Record<keyof TrainingInformation<"tabular">, unknown>;
      return true;
    }
    case "text": {
      type TextOnly = Omit<
        TrainingInformation<"text">,
        keyof TrainingInformation<DataType>
      >;
      const { contextLength, tokenizer }: Partial<TextOnly> = raw;
      // The tokenizer may be a model name or an already-loaded instance.
      if (
        (typeof tokenizer !== "string" &&
          !(tokenizer instanceof PreTrainedTokenizer)) ||
        typeof contextLength !== "number"
      )
        return false;

      const _: TrainingInformation<"text"> = {
        ...repack,
        dataType,
        contextLength,
        tokenizer,
      } satisfies Record<keyof TrainingInformation<"text">, unknown>;
      return true;
    }
    default:
      // Missing or unsupported dataType.
      return false;
  }
}