Skip to content

Commit

Permalink
*: enforce using override keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Oct 9, 2024
1 parent dfb607d commit d75338a
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 23 deletions.
4 changes: 2 additions & 2 deletions discojs/src/client/decentralized/decentralized_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export class DecentralizedClient extends Client {
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
* peers network information.
*/
async connect(): Promise<Model> {
override async connect(): Promise<Model> {
const model = await super.connect() // Get the server base model
const serverURL = new URL('', this.url.href)
switch (this.url.protocol) {
Expand Down Expand Up @@ -100,7 +100,7 @@ export class DecentralizedClient extends Client {
return model
}

async disconnect (): Promise<void> {
override async disconnect (): Promise<void> {
// Disconnect from peers
await this.#pool?.shutdown()
this.#pool = undefined
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/client/federated/federated_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ export class FederatedClient extends Client {
* as well as the latest training information: latest global model, current round and
* whether we are waiting for more participants.
*/
async connect(): Promise<Model> {
override async connect(): Promise<Model> {
const model = await super.connect() // Get the server base model

const serverURL = new URL("", this.url.href);
Expand Down
6 changes: 3 additions & 3 deletions discojs/src/client/local_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ import { Client } from "./client.js";
*/
export class LocalClient extends Client {

getNbOfParticipants(): number {
override getNbOfParticipants(): number {
return 1;
}

onRoundBeginCommunication(): Promise<void> {
override onRoundBeginCommunication(): Promise<void> {
return Promise.resolve();
}
// Simply return the local weights
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer> {
override onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer> {
return Promise.resolve(weights);
}
}
2 changes: 1 addition & 1 deletion discojs/src/dataset/data/image_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing/index.j
export class ImageData extends Data {
public readonly availablePreprocessing = IMAGE_PREPROCESSING

static async init (
static override async init (
dataset: tf.data.Dataset<tf.TensorContainer>,
task: Task,
size?: number
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/dataset/data/tabular_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { TABULAR_PREPROCESSING } from './preprocessing/index.js'
export class TabularData extends Data {
public readonly availablePreprocessing = TABULAR_PREPROCESSING

static async init (
static override async init (
dataset: tf.data.Dataset<tf.TensorContainer>,
task: Task,
size?: number
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/dataset/data/text_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { TEXT_PREPROCESSING } from './preprocessing/index.js'
export class TextData extends Data {
public readonly availablePreprocessing = TEXT_PREPROCESSING

static init (
static override init (
dataset: tf.data.Dataset<tf.TensorContainer>,
task: Task,
size?: number
Expand Down
20 changes: 10 additions & 10 deletions discojs/src/models/gpt/layers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import type { ModelSize } from './config.js'
class Range extends tf.layers.Layer {
static readonly className = 'Range'

computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
return inputShape
}

call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
override call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
return tf.tidy(() => {
if (Array.isArray(input)) {
// TODO support multitensor
Expand All @@ -30,11 +30,11 @@ tf.serialization.registerClass(Range)
class LogLayer extends tf.layers.Layer {
static readonly className = 'LogLayer'

computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
return inputShape
}

call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
override call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
return tf.tidy(() => {
if (Array.isArray(input)) {
input = input[0]
Expand Down Expand Up @@ -77,7 +77,7 @@ class CausalSelfAttention extends tf.layers.Layer {
this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0)
}

build (): void {
override build (): void {
this.cAttnKernel = this.addWeight(
'c_attn/kernel',
[this.nEmbd, 3 * this.nEmbd],
Expand All @@ -104,16 +104,16 @@ class CausalSelfAttention extends tf.layers.Layer {
)
}

computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
return inputShape
}

getConfig (): tf.serialization.ConfigDict {
override getConfig (): tf.serialization.ConfigDict {
const config = super.getConfig()
return Object.assign({}, config, this.config)
}

call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
override call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
return tf.tidy(() => {
if (this.cAttnKernel === undefined ||
this.cAttnBias === undefined ||
Expand Down Expand Up @@ -199,11 +199,11 @@ class GELU extends tf.layers.Layer {
super({})
}

computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
override computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
return inputShape
}

call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
override call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
return tf.tidy(() => {
if (Array.isArray(input)) {
// TODO support multitensor
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ class GPTModel extends tf.LayersModel {
return this.config
}

compile() {
override compile() {
if (this.optimizer !== undefined) return
this.optimizer = this.config.weightDecay !== 0
? getCustomAdam(this, this.config.lr, this.config.weightDecay)
: tf.train.adam(this.config.lr)
}

async fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History> {
override async fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History> {
const callbacks = trainingArgs.callbacks as tf.CustomCallbackArgs
const evalDataset = trainingArgs.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
await callbacks.onTrainBegin?.()
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/models/gpt/optimizers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class AdamW extends tf.AdamOptimizer {
this.gradientClipNorm = p.gradientClipNorm
}

applyGradients (variableGradients: Record<string, tf.Variable> | Array<{ name: string, tensor: tf.Tensor }>): void {
override applyGradients (variableGradients: Record<string, tf.Variable> | Array<{ name: string, tensor: tf.Tensor }>): void {
const varNames: string[] = Array.isArray(variableGradients)
? variableGradients.map((v) => v.name)
: Object.keys(variableGradients)
Expand Down
5 changes: 4 additions & 1 deletion tsconfig.base.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
"skipLibCheck": true,

// don't pollute
"noEmitOnError": true
"noEmitOnError": true,

// enforce using the override keyword
"noImplicitOverride": true
}
}

0 comments on commit d75338a

Please sign in to comment.