From a0e2b53459aa1969e723c153bb457c85861ba436 Mon Sep 17 00:00:00 2001 From: Yiming Date: Wed, 25 Dec 2024 22:59:53 +0800 Subject: [PATCH 01/16] feat(zmodel): add new functions `currentModel` and `currentOperation` (#1925) --- .../function-invocation-validator.ts | 12 +- packages/schema/src/res/stdlib.zmodel | 23 +++ packages/sdk/src/code-gen.ts | 5 + .../src/typescript-expression-transformer.ts | 77 ++++++-- .../with-policy/currentModel.test.ts | 185 ++++++++++++++++++ .../with-policy/currentOperation.test.ts | 154 +++++++++++++++ 6 files changed, 441 insertions(+), 15 deletions(-) create mode 100644 tests/integration/tests/enhancements/with-policy/currentModel.test.ts create mode 100644 tests/integration/tests/enhancements/with-policy/currentOperation.test.ts diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 8c11a2a72..343c75cad 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -87,7 +87,17 @@ export default class FunctionInvocationValidator implements AstValidator(expr.args[0]?.value); + if (arg && !allCasing.includes(arg)) { + accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { + node: expr.args[0], + }); + } + } else if ( funcAllowedContext.includes(ExpressionContext.AccessPolicy) || funcAllowedContext.includes(ExpressionContext.ValidationRule) ) { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 3316a90a9..483993d92 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -171,6 +171,29 @@ function hasSome(field: Any[], search: Any[]): Boolean { function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) +/** + * The name of the model for which the policy rule is defined. If the rule is + * inherited to a sub model, this function returns the name of the sub model. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentModel(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + +/** + * The operation for which the policy rule is defined for. Note that a rule with + * "all" operation is expanded to "create", "read", "update", and "delete" rules, + * and the function returns corresponding value for each expanded version. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentOperation(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + /** * Marks an attribute to be only applicable to certain field types. */ diff --git a/packages/sdk/src/code-gen.ts b/packages/sdk/src/code-gen.ts index 7b26cc0c4..67833b788 100644 --- a/packages/sdk/src/code-gen.ts +++ b/packages/sdk/src/code-gen.ts @@ -47,6 +47,11 @@ export async function saveProject(project: Project) { * Emit a TS project to JS files. */ export async function emitProject(project: Project) { + // ignore type checking for all source files + for (const sf of project.getSourceFiles()) { + sf.insertStatements(0, '// @ts-nocheck'); + } + const errors = project.getPreEmitDiagnostics().filter((d) => d.getCategory() === DiagnosticCategory.Error); if (errors.length > 0) { console.error('Error compiling generated code:'); diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 9a884ebdf..801db4d4f 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -20,6 +20,7 @@ import { isNullExpr, isThisExpr, } from '@zenstackhq/language/ast'; +import { getContainerOfType } from 'langium'; import { P, match } from 'ts-pattern'; import { ExpressionContext } from './constants'; import { getEntityCheckerFunctionName } from './names'; @@ -40,6 +41,8 @@ type Options = { operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete'; }; +type Casing = 'original' | 'upper' | 'lower' | 'capitalize' | 'uncapitalize'; + // a registry of function handlers marked with @func const functionHandlers = new Map(); @@ -150,7 +153,7 @@ export class TypeScriptExpressionTransformer { } const args = expr.args.map((arg) => arg.value); - return handler.value.call(this, args, normalizeUndefined); + return handler.value.call(this, expr, args, normalizeUndefined); } // #region function invocation handlers @@ -168,7 +171,7 @@ export class TypeScriptExpressionTransformer { } @func('length') - private _length(args: Expression[]) { + private _length(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const min = getLiteral(args[1]); const max = getLiteral(args[2]); @@ -188,7 +191,7 @@ export class TypeScriptExpressionTransformer { } @func('contains') - private _contains(args: Expression[], normalizeUndefined: boolean) { + private _contains(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const caseInsensitive = getLiteral(args[2]) === true; let result: string; @@ -201,34 +204,34 @@ export class TypeScriptExpressionTransformer { } @func('startsWith') - private _startsWith(args: Expression[], normalizeUndefined: boolean) { + private _startsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.startsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('endsWith') - private _endsWith(args: Expression[], normalizeUndefined: boolean) { + private _endsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.endsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('regex') - private _regex(args: Expression[]) { + private _regex(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); return this.ensureBooleanTernary(args[0], field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); } @func('email') - private _email(args: Expression[]) { + private _email(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().email().safeParse(${field}).success`); } @func('datetime') - private _datetime(args: Expression[]) { + private _datetime(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -238,20 +241,20 @@ export class TypeScriptExpressionTransformer { } @func('url') - private _url(args: Expression[]) { + private _url(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().url().safeParse(${field}).success`); } @func('has') - private _has(args: Expression[], normalizeUndefined: boolean) { + private _has(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.includes(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('hasEvery') - private _hasEvery(args: Expression[], normalizeUndefined: boolean) { + private _hasEvery(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -261,7 +264,7 @@ export class TypeScriptExpressionTransformer { } @func('hasSome') - private _hasSome(args: Expression[], normalizeUndefined: boolean) { + private _hasSome(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -271,13 +274,13 @@ export class TypeScriptExpressionTransformer { } @func('isEmpty') - private _isEmpty(args: Expression[]) { + private _isEmpty(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return `(!${field} || ${field}?.length === 0)`; } @func('check') - private _check(args: Expression[]) { + private _check(_invocation: InvocationExpr, args: Expression[]) { if (!isDataModelFieldReference(args[0])) { throw new TypeScriptExpressionTransformerError(`First argument of check() must be a field`); } @@ -309,6 +312,52 @@ export class TypeScriptExpressionTransformer { return `${entityCheckerFunc}(input.${fieldRef.target.$refText}, context)`; } + private toStringWithCaseChange(value: string, casing: Casing) { + if (!value) { + return "''"; + } + return match(casing) + .with('original', () => `'${value}'`) + .with('upper', () => `'${value.toUpperCase()}'`) + .with('lower', () => `'${value.toLowerCase()}'`) + .with('capitalize', () => `'${value.charAt(0).toUpperCase() + value.slice(1)}'`) + .with('uncapitalize', () => `'${value.charAt(0).toLowerCase() + value.slice(1)}'`) + .exhaustive(); + } + + @func('currentModel') + private _currentModel(invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + const containingModel = getContainerOfType(invocation, isDataModel); + if (!containingModel) { + throw new TypeScriptExpressionTransformerError('currentModel() must be called inside a model'); + } + return this.toStringWithCaseChange(containingModel.name, casing); + } + + @func('currentOperation') + private _currentOperation(_invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + if (!this.options.operationContext) { + throw new TypeScriptExpressionTransformerError( + 'currentOperation() must be called inside an access policy rule' + ); + } + let contextOperation = this.options.operationContext; + if (contextOperation === 'postUpdate') { + contextOperation = 'update'; + } + return this.toStringWithCaseChange(contextOperation, casing); + } + private ensureBoolean(expr: string) { if (this.options.context === ExpressionContext.ValidationRule) { // all fields are optional in a validation context, so we treat undefined diff --git a/tests/integration/tests/enhancements/with-policy/currentModel.test.ts b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts new file mode 100644 index 000000000..0b98314a4 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts @@ -0,0 +1,185 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentModel tests', () => { + it('works in models', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'USER') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'user') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model user { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'User') + } + + model post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model USER { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'uSER') + } + + model POST { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'POST') + } + ` + ); + + const db = enhance(); + await expect(db.USER.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.POST.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from abstract base', async () => { + const { enhance } = await loadSchema( + ` + abstract model Base { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from delegate base', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id + type String + @@delegate(type) + + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentModel()) + } + ` + ) + ).resolves.toContain('function "currentModel" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentModel('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts new file mode 100644 index 000000000..c56713316 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts @@ -0,0 +1,154 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentOperation tests', () => { + it('works with specific rules', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with all rule', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('all', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'CREATE') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'READ') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'Create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'create') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('uncapitalize') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('uncapitalize') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentOperation()) + } + ` + ) + ).resolves.toContain('function "currentOperation" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentOperation('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +}); From 1b7448f4ed2a6430fe0904c61255aa03edfd23d4 Mon Sep 17 00:00:00 2001 From: Eugen Istoc Date: Mon, 30 Dec 2024 00:09:21 -0300 Subject: [PATCH 02/16] feat: Add `@encrypted` enhancer (#1922) --- .../src/enhancements/edge/encrypted.ts | 1 + .../enhancements/node/create-enhancement.ts | 15 +- .../src/enhancements/node/encrypted.ts | 175 ++++++++++++++++++ packages/runtime/src/types.ts | 15 +- packages/schema/src/res/stdlib.zmodel | 8 + .../with-encrypted/with-encrypted.test.ts | 108 +++++++++++ 6 files changed, 319 insertions(+), 3 deletions(-) create mode 120000 packages/runtime/src/enhancements/edge/encrypted.ts create mode 100644 packages/runtime/src/enhancements/node/encrypted.ts create mode 100644 tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts diff --git a/packages/runtime/src/enhancements/edge/encrypted.ts b/packages/runtime/src/enhancements/edge/encrypted.ts new file mode 120000 index 000000000..96d88b82d --- /dev/null +++ b/packages/runtime/src/enhancements/edge/encrypted.ts @@ -0,0 +1 @@ +../node/encrypted.ts \ No newline at end of file diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index adec1fdf2..871f8a1b4 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -14,13 +14,14 @@ import { withJsonProcessor } from './json-processor'; import { Logger } from './logger'; import { withOmit } from './omit'; import { withPassword } from './password'; +import { withEncrypted } from './encrypted'; import { policyProcessIncludeRelationPayload, withPolicy } from './policy'; import type { PolicyDef } from './types'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted']; /** * Options for {@link createEnhancement} @@ -100,6 +101,7 @@ export function createEnhancement( } const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); + const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted')); const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); const hasTypeDefField = allFields.some((field) => field.isTypeDef); @@ -120,13 +122,22 @@ export function createEnhancement( } } - // password enhancement must be applied prior to policy because it changes then length of the field + // password and encrypted enhancement must be applied prior to policy because it changes then length of the field // and can break validation rules like `@length` if (hasPassword && kinds.includes('password')) { // @password proxy result = withPassword(result, options); } + if (hasEncrypted && kinds.includes('encrypted')) { + if (!options.encryption) { + throw new Error('Encryption options are required for @encrypted enhancement'); + } + + // @encrypted proxy + result = withEncrypted(result, options); + } + // 'policy' and 'validation' enhancements are both enabled by `withPolicy` if (kinds.includes('policy') || kinds.includes('validation')) { result = withPolicy(result, options, context); diff --git a/packages/runtime/src/enhancements/node/encrypted.ts b/packages/runtime/src/enhancements/node/encrypted.ts new file mode 100644 index 000000000..c6d6fc873 --- /dev/null +++ b/packages/runtime/src/enhancements/node/encrypted.ts @@ -0,0 +1,175 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +/* eslint-disable @typescript-eslint/no-unused-vars */ + +import { + FieldInfo, + NestedWriteVisitor, + enumerate, + getModelFields, + resolveField, + type PrismaWriteActionType, +} from '../../cross'; +import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types'; +import { InternalEnhancementOptions } from './create-enhancement'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; +import { QueryUtils } from './query-utils'; + +/** + * Gets an enhanced Prisma client that supports `@encrypted` attribute. + * + * @private + */ +export function withEncrypted( + prisma: DbClient, + options: InternalEnhancementOptions +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options), + 'encrypted' + ); +} + +class EncryptedHandler extends DefaultPrismaProxyHandler { + private queryUtils: QueryUtils; + private encoder = new TextEncoder(); + private decoder = new TextDecoder(); + + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); + + this.queryUtils = new QueryUtils(prisma, options); + + if (!options.encryption) throw new Error('Encryption options must be provided'); + + if (this.isCustomEncryption(options.encryption!)) { + if (!options.encryption.encrypt || !options.encryption.decrypt) + throw new Error('Custom encryption must provide encrypt and decrypt functions'); + } else { + if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided'); + if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes'); + } + } + + private async getKey(secret: Uint8Array): Promise { + return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']); + } + + private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption { + return 'encrypt' in encryption && 'decrypt' in encryption; + } + + private async encrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.encrypt(this.model, field, data); + } + + const key = await this.getKey(this.options.encryption!.encryptionKey); + const iv = crypto.getRandomValues(new Uint8Array(12)); + + const encrypted = await crypto.subtle.encrypt( + { + name: 'AES-GCM', + iv, + }, + key, + this.encoder.encode(data) + ); + + // Combine IV and encrypted data into a single array of bytes + const bytes = [...iv, ...new Uint8Array(encrypted)]; + + // Convert bytes to base64 string + return btoa(String.fromCharCode(...bytes)); + } + + private async decrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.decrypt(this.model, field, data); + } + + const key = await this.getKey(this.options.encryption!.encryptionKey); + + // Convert base64 back to bytes + const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0)); + + // First 12 bytes are IV, rest is encrypted data + const decrypted = await crypto.subtle.decrypt( + { + name: 'AES-GCM', + iv: bytes.slice(0, 12), + }, + key, + bytes.slice(12) + ); + + return this.decoder.decode(decrypted); + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; + if (args && args.data && actionsOfInterest.includes(action)) { + await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + } + return args; + } + + // base override + protected async processResultEntity(method: PrismaProxyActions, data: T): Promise { + if (!data || typeof data !== 'object') { + return data; + } + + for (const value of enumerate(data)) { + await this.doPostProcess(value, this.model); + } + + return data; + } + + private async doPostProcess(entityData: any, model: string) { + const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData); + + for (const field of getModelFields(entityData)) { + const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); + + if (!fieldInfo) { + continue; + } + + const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); + if (shouldDecrypt) { + // Don't decrypt null, undefined or empty string values + if (!entityData[field]) continue; + + try { + entityData[field] = await this.decrypt(fieldInfo, entityData[field]); + } catch (error) { + console.warn('Decryption failed, keeping original value:', error); + } + } + } + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + field: async (field, _action, data, context) => { + // Don't encrypt null, undefined or empty string values + if (!data) return; + + const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted'); + if (encAttr && field.type === 'String') { + try { + context.parent[field.name] = await this.encrypt(field, data); + } catch (error) { + throw new Error(`Encryption failed for field ${field.name}: ${error}`); + } + } + }, + }); + + await visitor.visit(model, action, args); + } +} diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 7c4df97c1..e691fc32c 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { z } from 'zod'; +import { FieldInfo } from './cross'; export type PrismaPromise = Promise & Record PrismaPromise>; @@ -133,6 +134,11 @@ export type EnhancementOptions = { * The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack. */ transactionIsolationLevel?: TransactionIsolationLevel; + + /** + * The encryption options for using the `encrypted` enhancement. + */ + encryption?: SimpleEncryption | CustomEncryption; }; /** @@ -145,7 +151,7 @@ export type EnhancementContext = { /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted'; /** * Function for transforming errors. @@ -166,3 +172,10 @@ export type ZodSchemas = { */ input?: Record>; }; + +export type CustomEncryption = { + encrypt: (model: string, field: FieldInfo, plain: string) => Promise; + decrypt: (model: string, field: FieldInfo, cipher: string) => Promise; +}; + +export type SimpleEncryption = { encryptionKey: Uint8Array }; diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 483993d92..a0a0a41f8 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -575,6 +575,14 @@ attribute @@auth() @@@supportTypeDef */ attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField]) + +/** + * Indicates that the field is encrypted when storing in the DB and should be decrypted when read + * + * ZenStack uses the Web Crypto API to encrypt and decrypt the field. + */ +attribute @encrypted() @@@targetField([StringField]) + /** * Indicates that the field should be omitted when read from the generated services. */ diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts new file mode 100644 index 000000000..1e0544c0b --- /dev/null +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -0,0 +1,108 @@ +import { FieldInfo } from '@zenstackhq/runtime'; +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('Encrypted test', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(async () => { + process.chdir(origDir); + }); + + it('Simple encryption test', async () => { + const { enhance } = await loadSchema(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`); + + const sudoDb = enhance(undefined, { kinds: [] }); + const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); + + const db = enhance(undefined, { + kinds: ['encrypted'], + encryption: { encryptionKey }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const sudoRead = await sudoDb.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(sudoRead.encrypted_value).not.toBe('abc123'); + }); + + it('Custom encryption test', async () => { + const { enhance } = await loadSchema(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`); + + const sudoDb = enhance(undefined, { kinds: [] }); + const db = enhance(undefined, { + kinds: ['encrypted'], + encryption: { + encrypt: async (model: string, field: FieldInfo, data: string) => { + // Add _enc to the end of the input + return data + '_enc'; + }, + decrypt: async (model: string, field: FieldInfo, cipher: string) => { + // Remove _enc from the end of the input explicitly + if (cipher.endsWith('_enc')) { + return cipher.slice(0, -4); // Remove last 4 characters (_enc) + } + + return cipher; + }, + }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + const sudoRead = await sudoDb.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create.encrypted_value).toBe('abc123'); + expect(read.encrypted_value).toBe('abc123'); + expect(sudoRead.encrypted_value).toBe('abc123_enc'); + }); +}); From dcef942d881b4d0f6cc6fc2ff2978e85c6948cc9 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 30 Dec 2024 11:10:25 +0800 Subject: [PATCH 03/16] chore: bump version (#1926) --- package.json | 2 +- packages/ide/jetbrains/build.gradle.kts | 2 +- packages/ide/jetbrains/package.json | 2 +- packages/language/package.json | 2 +- packages/misc/redwood/package.json | 2 +- packages/plugins/openapi/package.json | 2 +- packages/plugins/swr/package.json | 2 +- packages/plugins/tanstack-query/package.json | 2 +- packages/plugins/trpc/package.json | 2 +- packages/runtime/package.json | 2 +- packages/schema/package.json | 2 +- packages/sdk/package.json | 2 +- packages/server/package.json | 2 +- packages/testtools/package.json | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/package.json b/package.json index 715dfb26e..e85f7cdcb 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "2.10.2", + "version": "2.11.0", "description": "", "scripts": { "build": "pnpm -r --filter=\"!./packages/ide/*\" build", diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts index a2fc573a5..8412fc899 100644 --- a/packages/ide/jetbrains/build.gradle.kts +++ b/packages/ide/jetbrains/build.gradle.kts @@ -9,7 +9,7 @@ plugins { } group = "dev.zenstack" -version = "2.10.2" +version = "2.11.0" repositories { mavenCentral() diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 331700e89..694b784e9 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -1,6 +1,6 @@ { "name": "jetbrains", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack JetBrains IDE Plugin", "description": "ZenStack JetBrains IDE plugin", "homepage": "https://zenstack.dev", diff --git a/packages/language/package.json b/packages/language/package.json index aa9fdb382..1f4d72647 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json index f49e46f0b..159ab6d29 100644 --- a/packages/misc/redwood/package.json +++ b/packages/misc/redwood/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/redwood", "displayName": "ZenStack RedwoodJS Integration", - "version": "2.10.2", + "version": "2.11.0", "description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.", "repository": { "type": "git", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 508dce2b5..e29c8f0e7 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index a3b0c1018..ba61604c7 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 713457bf6..70d538457 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 2aa5ad5d5..c900331e5 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 51d026db7..3dcd6c4d6 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "2.10.2", + "version": "2.11.0", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/schema/package.json b/packages/schema/package.json index 45740d15f..93fdfc1f0 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "FullStack enhancement for Prisma ORM: seamless integration from database to UI", - "version": "2.10.2", + "version": "2.11.0", "author": { "name": "ZenStack Team" }, diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 58fd11aee..a49c0bbdd 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/server/package.json b/packages/server/package.json index 806471ecb..85b5ec809 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "2.10.2", + "version": "2.11.0", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index c47c29c65..74db32925 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "2.10.2", + "version": "2.11.0", "description": "ZenStack Test Tools", "main": "index.js", "private": true, From 8486f6447bfb89279118998bd95a7156c244cb72 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 30 Dec 2024 15:22:48 +0800 Subject: [PATCH 04/16] feat: improvements to "encryption" enhancement (#1927) --- .../enhancements/node/create-enhancement.ts | 4 +- .../src/enhancements/node/encrypted.ts | 21 ++- packages/runtime/src/types.ts | 2 +- .../attribute-application-validator.ts | 16 +- .../with-encrypted/with-encrypted.test.ts | 164 +++++++++++++++++- 5 files changed, 186 insertions(+), 21 deletions(-) diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index 871f8a1b4..6090f523f 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -21,7 +21,7 @@ import type { PolicyDef } from './types'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encryption']; /** * Options for {@link createEnhancement} @@ -129,7 +129,7 @@ export function createEnhancement( result = withPassword(result, options); } - if (hasEncrypted && kinds.includes('encrypted')) { + if (hasEncrypted && kinds.includes('encryption')) { if (!options.encryption) { throw new Error('Encryption options are required for @encrypted enhancement'); } diff --git a/packages/runtime/src/enhancements/node/encrypted.ts b/packages/runtime/src/enhancements/node/encrypted.ts index c6d6fc873..d5db66690 100644 --- a/packages/runtime/src/enhancements/node/encrypted.ts +++ b/packages/runtime/src/enhancements/node/encrypted.ts @@ -9,8 +9,9 @@ import { resolveField, type PrismaWriteActionType, } from '../../cross'; -import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types'; +import { CustomEncryption, DbClientContract, SimpleEncryption } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; +import { Logger } from './logger'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; import { QueryUtils } from './query-utils'; @@ -27,7 +28,7 @@ export function withEncrypted( prisma, options.modelMeta, (_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options), - 'encrypted' + 'encryption' ); } @@ -35,20 +36,24 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { private queryUtils: QueryUtils; private encoder = new TextEncoder(); private decoder = new TextDecoder(); + private logger: Logger; constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); this.queryUtils = new QueryUtils(prisma, options); + this.logger = new Logger(prisma); - if (!options.encryption) throw new Error('Encryption options must be provided'); + if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided'); if (this.isCustomEncryption(options.encryption!)) { if (!options.encryption.encrypt || !options.encryption.decrypt) - throw new Error('Custom encryption must provide encrypt and decrypt functions'); + throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions'); } else { - if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided'); - if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes'); + if (!options.encryption.encryptionKey) + throw this.queryUtils.unknownError('Encryption key must be provided'); + if (options.encryption.encryptionKey.length !== 32) + throw this.queryUtils.unknownError('Encryption key must be 32 bytes'); } } @@ -147,7 +152,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { try { entityData[field] = await this.decrypt(fieldInfo, entityData[field]); } catch (error) { - console.warn('Decryption failed, keeping original value:', error); + this.logger.warn(`Decryption failed, keeping original value: ${error}`); } } } @@ -164,7 +169,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { try { context.parent[field.name] = await this.encrypt(field, data); } catch (error) { - throw new Error(`Encryption failed for field ${field.name}: ${error}`); + this.queryUtils.unknownError(`Encryption failed for field ${field.name}: ${error}`); } } }, diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index e691fc32c..012c94699 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -151,7 +151,7 @@ export type EnhancementContext = { /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encryption'; /** * Function for transforming errors. diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index a7c0fef9a..0e1d8e885 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -25,7 +25,7 @@ import { isRelationshipField, resolved, } from '@zenstackhq/sdk'; -import { ValidationAcceptor, streamAst } from 'langium'; +import { ValidationAcceptor, streamAllContents, streamAst } from 'langium'; import pluralize from 'pluralize'; import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; @@ -138,6 +138,9 @@ export default class AttributeApplicationValidator implements AstValidator { + if (isDataModelFieldReference(node) && hasAttribute(node.target.ref as DataModelField, '@encrypted')) { + accept('error', `Encrypted fields cannot be used in policy rules`, { node }); + } + }); + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 1e0544c0b..9b6307822 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -1,9 +1,10 @@ import { FieldInfo } from '@zenstackhq/runtime'; -import { loadSchema } from '@zenstackhq/testtools'; +import { loadSchema, loadModelWithError } from '@zenstackhq/testtools'; import path from 'path'; describe('Encrypted test', () => { let origDir: string; + const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); beforeAll(async () => { origDir = path.resolve('.'); @@ -14,21 +15,25 @@ describe('Encrypted test', () => { }); it('Simple encryption test', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema( + ` model User { id String @id @default(cuid()) encrypted_value String @encrypted() @@allow('all', true) - }`); + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); const sudoDb = enhance(undefined, { kinds: [] }); - const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64')); - const db = enhance(undefined, { - kinds: ['encrypted'], - encryption: { encryptionKey }, - }); + const db = enhance(); const create = await db.user.create({ data: { @@ -49,9 +54,50 @@ describe('Encrypted test', () => { }, }); + const rawRead = await prisma.user.findUnique({ where: { id: '1' } }); + expect(create.encrypted_value).toBe('abc123'); expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).not.toBe('abc123'); + expect(rawRead.encrypted_value).not.toBe('abc123'); + }); + + it('Multi-field encryption test', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + x1 String @encrypted() + x2 String @encrypted() + + @@allow('all', true) + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + x1: 'abc123', + x2: '123abc', + }, + }); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + }); + + expect(create).toMatchObject({ x1: 'abc123', x2: '123abc' }); + expect(read).toMatchObject({ x1: 'abc123', x2: '123abc' }); }); it('Custom encryption test', async () => { @@ -65,7 +111,7 @@ describe('Encrypted test', () => { const sudoDb = enhance(undefined, { kinds: [] }); const db = enhance(undefined, { - kinds: ['encrypted'], + kinds: ['encryption'], encryption: { encrypt: async (model: string, field: FieldInfo, data: string) => { // Add _enc to the end of the input @@ -105,4 +151,104 @@ describe('Encrypted test', () => { expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).toBe('abc123_enc'); }); + + it('Only supports string fields', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id @default(cuid()) + encrypted_value Bytes @encrypted() + }` + ) + ).resolves.toContain(`attribute \"@encrypted\" cannot be used on this type of field`); + }); + + it('Returns cipher text when decryption fails', async () => { + const { enhance, enhanceRaw, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + + @@allow('all', true) + }`, + { enhancements: ['encryption'] } + ); + + const db = enhance(undefined, { + kinds: ['encryption'], + encryption: { encryptionKey }, + }); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + const db1 = enhanceRaw(prisma, undefined, { + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) }, + }); + const read = await db1.user.findUnique({ where: { id: '1' } }); + expect(read.encrypted_value).toBeTruthy(); + expect(read.encrypted_value).not.toBe('abc123'); + }); + + it('Works with length validation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() @length(0, 6) + + @@allow('all', true) + }`, + { + enhanceOptions: { encryption: { encryptionKey } }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + encrypted_value: 'abc123', + }, + }); + expect(create.encrypted_value).toBe('abc123'); + + await expect( + db.user.create({ + data: { id: '2', encrypted_value: 'abc1234' }, + }) + ).toBeRejectedByPolicy(); + }); + + it('Complains when encrypted fields are used in model-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + @@allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); + + it('Complains when encrypted fields are used in field-level policy rules', async () => { + await expect( + loadModelWithError(` + model User { + id String @id @default(cuid()) + encrypted_value String @encrypted() + value Int @allow('all', encrypted_value != 'abc123') + } + `) + ).resolves.toContain(`Encrypted fields cannot be used in policy rules`); + }); }); From 93246f39e24e40c47d150eccbae67572bbdbb4f3 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 30 Dec 2024 15:44:58 +0800 Subject: [PATCH 05/16] chore(encryption): rename encrypted to encryption (#1928) --- packages/runtime/src/enhancements/edge/encrypted.ts | 1 - packages/runtime/src/enhancements/edge/encryption.ts | 1 + packages/runtime/src/enhancements/node/create-enhancement.ts | 2 +- .../src/enhancements/node/{encrypted.ts => encryption.ts} | 0 4 files changed, 2 insertions(+), 2 deletions(-) delete mode 120000 packages/runtime/src/enhancements/edge/encrypted.ts create mode 120000 packages/runtime/src/enhancements/edge/encryption.ts rename packages/runtime/src/enhancements/node/{encrypted.ts => encryption.ts} (100%) diff --git a/packages/runtime/src/enhancements/edge/encrypted.ts b/packages/runtime/src/enhancements/edge/encrypted.ts deleted file mode 120000 index 96d88b82d..000000000 --- a/packages/runtime/src/enhancements/edge/encrypted.ts +++ /dev/null @@ -1 +0,0 @@ -../node/encrypted.ts \ No newline at end of file diff --git a/packages/runtime/src/enhancements/edge/encryption.ts b/packages/runtime/src/enhancements/edge/encryption.ts new file mode 120000 index 000000000..9931fc8ea --- /dev/null +++ b/packages/runtime/src/enhancements/edge/encryption.ts @@ -0,0 +1 @@ +../node/encryption.ts \ No newline at end of file diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index 6090f523f..07f905182 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -10,11 +10,11 @@ import type { } from '../../types'; import { withDefaultAuth } from './default-auth'; import { withDelegate } from './delegate'; +import { withEncrypted } from './encryption'; import { withJsonProcessor } from './json-processor'; import { Logger } from './logger'; import { withOmit } from './omit'; import { withPassword } from './password'; -import { withEncrypted } from './encrypted'; import { policyProcessIncludeRelationPayload, withPolicy } from './policy'; import type { PolicyDef } from './types'; diff --git a/packages/runtime/src/enhancements/node/encrypted.ts b/packages/runtime/src/enhancements/node/encryption.ts similarity index 100% rename from packages/runtime/src/enhancements/node/encrypted.ts rename to packages/runtime/src/enhancements/node/encryption.ts From f609c862ed53e722da1628b4533a38292ac1414b Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 30 Dec 2024 17:39:59 +0800 Subject: [PATCH 06/16] fix(openapi): type "id" field according to ZModel schema type (#1929) --- .../plugins/openapi/src/rest-generator.ts | 12 +++++++-- .../openapi/tests/openapi-restful.test.ts | 26 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/packages/plugins/openapi/src/rest-generator.ts b/packages/plugins/openapi/src/rest-generator.ts index e6da0268b..98c6abcb6 100644 --- a/packages/plugins/openapi/src/rest-generator.ts +++ b/packages/plugins/openapi/src/rest-generator.ts @@ -906,16 +906,24 @@ export class RESTfulOpenAPIGenerator extends OpenAPIGeneratorBase { }, }; + let idFieldSchema: OAPI.SchemaObject = { type: 'string' }; + if (idFields.length === 1) { + // FIXME: JSON:API actually requires id field to be a string, + // but currently the RESTAPIHandler returns the original data + // type as declared in the ZModel schema. + idFieldSchema = this.fieldTypeToOpenAPISchema(idFields[0].type); + } + if (mode === 'create') { // 'id' is required if there's no default value const idFields = model.fields.filter((f) => isIdField(f)); if (idFields.length === 1 && !hasAttribute(idFields[0], '@default')) { - properties = { id: { type: 'string' }, ...properties }; + properties = { id: idFieldSchema, ...properties }; toplevelRequired.unshift('id'); } } else { // 'id' always required for read and update - properties = { id: { type: 'string' }, ...properties }; + properties = { id: idFieldSchema, ...properties }; toplevelRequired.unshift('id'); } diff --git a/packages/plugins/openapi/tests/openapi-restful.test.ts b/packages/plugins/openapi/tests/openapi-restful.test.ts index 8fd0880ff..51d16e888 100644 --- a/packages/plugins/openapi/tests/openapi-restful.test.ts +++ b/packages/plugins/openapi/tests/openapi-restful.test.ts @@ -84,7 +84,7 @@ model Bar { const { name: output } = tmp.fileSync({ postfix: '.yaml' }); - const options = buildOptions(model, modelFile, output, '3.1.0'); + const options = buildOptions(model, modelFile, output, specVersion); await generate(model, options, dmmf); console.log(`OpenAPI specification generated for ${specVersion}: ${output}`); @@ -324,7 +324,7 @@ model Foo { const { name: output } = tmp.fileSync({ postfix: '.yaml' }); - const options = buildOptions(model, modelFile, output, '3.1.0'); + const options = buildOptions(model, modelFile, output, specVersion); await generate(model, options, dmmf); console.log(`OpenAPI specification generated for ${specVersion}: ${output}`); @@ -340,6 +340,28 @@ model Foo { } }); + it('int field as id', async () => { + const { model, dmmf, modelFile } = await loadZModelAndDmmf(` +plugin openapi { + provider = '${normalizePath(path.resolve(__dirname, '../dist'))}' +} + +model Foo { + id Int @id @default(autoincrement()) +} + `); + + const { name: output } = tmp.fileSync({ postfix: '.yaml' }); + + const options = buildOptions(model, modelFile, output, '3.0.0'); + await generate(model, options, dmmf); + console.log(`OpenAPI specification generated: ${output}`); + await OpenAPIParser.validate(output); + + const parsed = YAML.parse(fs.readFileSync(output, 'utf-8')); + expect(parsed.components.schemas.Foo.properties.id.type).toBe('integer'); + }); + it('exposes individual fields from a compound id as attributes', async () => { const { model, dmmf, modelFile } = await loadZModelAndDmmf(` plugin openapi { From b477ad886e82f530486d769481221fae5b545c54 Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 31 Dec 2024 14:23:25 +0800 Subject: [PATCH 07/16] fix(zmodel): check cyclic inheritance (#1931) --- .../validator/datamodel-validator.ts | 24 ++++++++++++ .../validation/cyclic-inheritance.test.ts | 39 +++++++++++++++++++ packages/sdk/src/utils.ts | 12 +++++- 3 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 packages/schema/tests/schema/validation/cyclic-inheritance.test.ts diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 9054c82c6..630bf0085 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -33,6 +33,10 @@ export default class DataModelValidator implements AstValidator { validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); + + if (dm.superTypes.length > 0) { + this.validateInheritance(dm, accept); + } } private validateFields(dm: DataModel, accept: ValidationAcceptor) { @@ -407,6 +411,26 @@ export default class DataModelValidator implements AstValidator { }); } } + + private validateInheritance(dm: DataModel, accept: ValidationAcceptor) { + const seen = [dm]; + const todo: DataModel[] = dm.superTypes.map((superType) => superType.ref!); + while (todo.length > 0) { + const current = todo.shift()!; + if (seen.includes(current)) { + accept( + 'error', + `Circular inheritance detected: ${seen.map((m) => m.name).join(' -> ')} -> ${current.name}`, + { + node: dm, + } + ); + return; + } + seen.push(current); + todo.push(...current.superTypes.map((superType) => superType.ref!)); + } + } } export interface MissingOppositeRelationData { diff --git a/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts b/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts new file mode 100644 index 000000000..494dad2be --- /dev/null +++ b/packages/schema/tests/schema/validation/cyclic-inheritance.test.ts @@ -0,0 +1,39 @@ +import { loadModelWithError } from '../../utils'; + +describe('Cyclic inheritance', () => { + it('abstract inheritance', async () => { + const errors = await loadModelWithError( + ` + abstract model A extends B {} + abstract model B extends A {} + model C extends B { + id Int @id + } + ` + ); + expect(errors).toContain('Circular inheritance detected: A -> B -> A'); + expect(errors).toContain('Circular inheritance detected: B -> A -> B'); + expect(errors).toContain('Circular inheritance detected: C -> B -> A -> B'); + }); + + it('delegate inheritance', async () => { + const errors = await loadModelWithError( + ` + model A extends B { + typeA String + @@delegate(typeA) + } + model B extends A { + typeB String + @@delegate(typeB) + } + model C extends B { + id Int @id + } + ` + ); + expect(errors).toContain('Circular inheritance detected: A -> B -> A'); + expect(errors).toContain('Circular inheritance detected: B -> A -> B'); + expect(errors).toContain('Circular inheritance detected: C -> B -> A -> B'); + }); +}); diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 46c2a82c1..ecb6895eb 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -544,8 +544,16 @@ export function getModelFieldsWithBases(model: DataModel, includeDelegate = true } } -export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] { +export function getRecursiveBases( + dataModel: DataModel, + includeDelegate = true, + seen = new Set() +): DataModel[] { const result: DataModel[] = []; + if (seen.has(dataModel)) { + return result; + } + seen.add(dataModel); dataModel.superTypes.forEach((superType) => { const baseDecl = superType.ref; if (baseDecl) { @@ -553,7 +561,7 @@ export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): return; } result.push(baseDecl); - result.push(...getRecursiveBases(baseDecl, includeDelegate)); + result.push(...getRecursiveBases(baseDecl, includeDelegate, seen)); } }); return result; From 2eecae53340c5f78a669b5017c8c0a796519f997 Mon Sep 17 00:00:00 2001 From: Yiming Date: Thu, 2 Jan 2025 10:00:39 +0800 Subject: [PATCH 08/16] fix(encryption): decrypt fields in nested read results (#1934) --- .../src/enhancements/node/encryption.ts | 28 ++++++---- .../with-encrypted/with-encrypted.test.ts | 53 ++++++++++++++++--- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts index d5db66690..3d0f738d4 100644 --- a/packages/runtime/src/enhancements/node/encryption.ts +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -138,21 +138,29 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData); for (const field of getModelFields(entityData)) { - const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); + // Don't decrypt null, undefined or empty string values + if (!entityData[field]) continue; + const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); if (!fieldInfo) { continue; } - const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); - if (shouldDecrypt) { - // Don't decrypt null, undefined or empty string values - if (!entityData[field]) continue; - - try { - entityData[field] = await this.decrypt(fieldInfo, entityData[field]); - } catch (error) { - this.logger.warn(`Decryption failed, keeping original value: ${error}`); + if (fieldInfo.isDataModel) { + const items = + fieldInfo.isArray && Array.isArray(entityData[field]) ? entityData[field] : [entityData[field]]; + for (const item of items) { + // recurse + await this.doPostProcess(item, fieldInfo.type); + } + } else { + const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); + if (shouldDecrypt) { + try { + entityData[field] = await this.decrypt(fieldInfo, entityData[field]); + } catch (error) { + this.logger.warn(`Decryption failed, keeping original value: ${error}`); + } } } } diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 9b6307822..3f02c3f70 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -20,8 +20,6 @@ describe('Encrypted test', () => { model User { id String @id @default(cuid()) encrypted_value String @encrypted() - - @@allow('all', true) }`, { enhancements: ['encryption'], @@ -62,6 +60,52 @@ describe('Encrypted test', () => { expect(rawRead.encrypted_value).not.toBe('abc123'); }); + it('Decrypts nested fields', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + posts Post[] + } + + model Post { + id String @id @default(cuid()) + title String @encrypted() + author User @relation(fields: [authorId], references: [id]) + authorId String + } + `, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + + const create = await db.user.create({ + data: { + id: '1', + posts: { create: { title: 'Post1' } }, + }, + include: { posts: true }, + }); + expect(create.posts[0].title).toBe('Post1'); + + const read = await db.user.findUnique({ + where: { + id: '1', + }, + include: { posts: true }, + }); + expect(read.posts[0].title).toBe('Post1'); + + const rawRead = await prisma.user.findUnique({ where: { id: '1' }, include: { posts: true } }); + expect(rawRead.posts[0].title).not.toBe('Post1'); + }); + it('Multi-field encryption test', async () => { const { enhance } = await loadSchema( ` @@ -69,8 +113,6 @@ describe('Encrypted test', () => { id String @id @default(cuid()) x1 String @encrypted() x2 String @encrypted() - - @@allow('all', true) }`, { enhancements: ['encryption'], @@ -105,8 +147,6 @@ describe('Encrypted test', () => { model User { id String @id @default(cuid()) encrypted_value String @encrypted() - - @@allow('all', true) }`); const sudoDb = enhance(undefined, { kinds: [] }); @@ -203,7 +243,6 @@ describe('Encrypted test', () => { model User { id String @id @default(cuid()) encrypted_value String @encrypted() @length(0, 6) - @@allow('all', true) }`, { From 1956bdb461858cf5e0562434f47a3678d493b142 Mon Sep 17 00:00:00 2001 From: Yiming Date: Thu, 2 Jan 2025 16:39:29 +0800 Subject: [PATCH 09/16] fix(delegate): delegate model's guards are not properly including concrete models (#1932) --- .../src/plugins/enhancer/enhance/index.ts | 17 +--- .../enhancer/policy/expression-writer.ts | 24 +++--- .../enhancer/policy/policy-guard-generator.ts | 2 + .../src/plugins/prisma/schema-generator.ts | 5 +- packages/schema/src/utils/ast-utils.ts | 28 ++++++- .../with-delegate/policy-interaction.test.ts | 80 +++++++++++++++++++ tests/regression/tests/issue-1930.test.ts | 80 +++++++++++++++++++ 7 files changed, 207 insertions(+), 29 deletions(-) create mode 100644 tests/regression/tests/issue-1930.test.ts diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index ba8c50feb..689ddaf2c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -24,7 +24,6 @@ import { isArrayExpr, isDataModel, isGeneratorDecl, - isReferenceExpr, isTypeDef, type Model, } from '@zenstackhq/sdk/ast'; @@ -45,6 +44,7 @@ import { } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; +import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils'; import { execPackage } from '../../../utils/exec-utils'; import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils'; import { trackPrismaSchemaError } from '../../prisma'; @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara this.model.declarations .filter((d): d is DataModel => isDelegateModel(d)) .forEach((dm) => { - const concreteModels = this.model.declarations.filter( - (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) - ); + const concreteModels = getConcreteModels(dm); if (concreteModels.length > 0) { delegateInfo.push([dm, concreteModels]); } @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const typeName = typeAlias.getName(); const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName); if (payloadRecord) { - const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]); + const discriminatorDecl = getDiscriminatorField(payloadRecord[0]); if (discriminatorDecl) { source = `${payloadRecord[1] .map( @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara .filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX)); } - private getDiscriminatorField(delegate: DataModel) { - const delegateAttr = getAttribute(delegate, '@@delegate'); - if (!delegateAttr) { - return undefined; - } - const arg = delegateAttr.args[0]?.value; - return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; - } - private saveSourceFile(sf: SourceFile) { if (this.options.preserveTsFiles) { saveSourceFile(sf); diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 645e02cd1..0d792bdc1 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -839,16 +839,18 @@ export class ExpressionWriter { operation = this.options.operationContext; } - this.block(() => { - if (operation === 'postUpdate') { - // 'postUpdate' policies are not delegated to relations, just use constant `false` here - // e.g.: - // @@allow('all', check(author)) should not delegate "postUpdate" to author - this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`); - } else { - const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); - this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`); - } - }); + this.block(() => + this.writeFieldCondition(fieldRef, () => { + if (operation === 'postUpdate') { + // 'postUpdate' policies are not delegated to relations, just use constant `false` here + // e.g.: + // @@allow('all', check(author)) should not delegate "postUpdate" to author + this.writer.write(FALSE); + } else { + const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); + this.writer.write(`${targetGuardFunc}(context, db)`); + } + }) + ); } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 8206f797b..9ffe41dcb 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -454,6 +454,8 @@ export class PolicyGenerator { writer: CodeBlockWriter, sourceFile: SourceFile ) { + // first handle several cases where a constant function can be used + if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 96a3b15f5..a0bde1769 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -57,6 +57,7 @@ import path from 'path'; import semver from 'semver'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; +import { getConcreteModels } from '../../utils/ast-utils'; import { execPackage } from '../../utils/exec-utils'; import { isDefaultWithAuth } from '../enhancer/enhancer-utils'; import { @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator { } // collect concrete models inheriting this model - const concreteModels = decl.$container.declarations.filter( - (d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl) - ); + const concreteModels = getConcreteModels(decl); // generate an optional relation field in delegate base model to each concrete model concreteModels.forEach((concrete) => { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index a6fab7ea5..0e462547f 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -2,6 +2,7 @@ import { BinaryExpr, DataModel, DataModelAttribute, + DataModelField, Expression, InheritableNode, isBinaryExpr, @@ -9,12 +10,13 @@ import { isDataModelField, isInvocationExpr, isModel, + isReferenceExpr, isTypeDef, Model, ModelImport, TypeDef, } from '@zenstackhq/language/ast'; -import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; +import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode } return undefined; } + +/** + * Gets all concrete models that inherit from the given delegate model + */ +export function getConcreteModels(dataModel: DataModel): DataModel[] { + if (!isDelegateModel(dataModel)) { + return []; + } + return dataModel.$container.declarations.filter( + (d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel) + ); +} + +/** + * Gets the discriminator field for the given delegate model + */ +export function getDiscriminatorField(dataModel: DataModel) { + const delegateAttr = getAttribute(dataModel, '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const arg = delegateAttr.args[0]?.value; + return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; +} diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index d149a6392..67fc456af 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => { expect(foundPost2.foo).toBeUndefined(); expect(foundPost2.bar).toBeUndefined(); }); + + it('respects concrete policies when read as base optional relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset? + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@allow('create', true) + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.user.create({ data: { id: 1 } }); + await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read.asset).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); + + it('respects concrete policies when read as base required relation', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + asset Asset @relation(fields: [assetId], references: [id]) + assetId Int @unique + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User? + type String + + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + private Boolean + @@deny('read', private) + } + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } }); + await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({ + asset: expect.objectContaining({ type: 'Post' }), + }); + + const db = enhance(); + const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } }); + expect(read).toBeTruthy(); + expect(read.asset.title).toBeUndefined(); + }); }); diff --git a/tests/regression/tests/issue-1930.test.ts b/tests/regression/tests/issue-1930.test.ts new file mode 100644 index 000000000..762369321 --- /dev/null +++ b/tests/regression/tests/issue-1930.test.ts @@ -0,0 +1,80 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1930', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` +model Organization { + id String @id @default(cuid()) + entities Entity[] + + @@allow('all', true) +} + +model Entity { + id String @id @default(cuid()) + org Organization? @relation(fields: [orgId], references: [id]) + orgId String? + contents EntityContent[] + entityType String + isDeleted Boolean @default(false) + + @@delegate(entityType) + + @@allow('all', !isDeleted) +} + +model EntityContent { + id String @id @default(cuid()) + entity Entity @relation(fields: [entityId], references: [id]) + entityId String + + entityContentType String + + @@delegate(entityContentType) + + @@allow('create', true) + @@allow('read', check(entity)) +} + +model Article extends Entity { +} + +model ArticleContent extends EntityContent { + body String? +} + +model OtherContent extends EntityContent { + data Int +} + ` + ); + + const fullDb = enhance(undefined, { kinds: ['delegate'] }); + const org = await fullDb.organization.create({ data: {} }); + const article = await fullDb.article.create({ + data: { org: { connect: { id: org.id } } }, + }); + + const db = enhance(); + + // normal create/read + await expect( + db.articleContent.create({ + data: { body: 'abc', entity: { connect: { id: article.id } } }, + }) + ).toResolveTruthy(); + await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({ + contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]), + }); + + // deleted article's contents are not readable + const deletedArticle = await fullDb.article.create({ + data: { org: { connect: { id: org.id } }, isDeleted: true }, + }); + const content1 = await fullDb.articleContent.create({ + data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } }, + }); + await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull(); + }); +}); From bcbfb9ab206922e97a4d0c50c5e65b1eb0dbac2e Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 6 Jan 2025 15:27:32 +0800 Subject: [PATCH 10/16] fix(delegate): support _count select of base fields (#1937) --- .../runtime/src/enhancements/node/delegate.ts | 145 +++++++++++++----- .../with-delegate/enhanced-client.test.ts | 60 ++++++++ tests/regression/tests/issue-1467.test.ts | 51 ++++++ 3 files changed, 217 insertions(+), 39 deletions(-) create mode 100644 tests/regression/tests/issue-1467.test.ts diff --git a/packages/runtime/src/enhancements/node/delegate.ts b/packages/runtime/src/enhancements/node/delegate.ts index 80fd09f17..06c1526e5 100644 --- a/packages/runtime/src/enhancements/node/delegate.ts +++ b/packages/runtime/src/enhancements/node/delegate.ts @@ -180,47 +180,102 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { return; } - for (const kind of ['select', 'include'] as const) { - if (args[kind] && typeof args[kind] === 'object') { - for (const [field, value] of Object.entries(args[kind])) { - const fieldInfo = resolveField(this.options.modelMeta, model, field); - if (!fieldInfo) { - continue; - } + // there're two cases where we need to inject polymorphic base hierarchy for fields + // defined in base models + // 1. base fields mentioned in select/include clause + // { select: { fieldFromBase: true } } => { select: { delegate_aux_[Base]: { fieldFromBase: true } } } + // 2. base fields mentioned in _count select/include clause + // { select: { _count: { select: { fieldFromBase: true } } } } => { select: { delegate_aux_[Base]: { select: { _count: { select: { fieldFromBase: true } } } } } } + // + // Note that although structurally similar, we need to correctly deal with different injection location of the "delegate_aux" hierarchy + + // selectors for the above two cases + const selectors = [ + // regular select: { select: { field: true } } + (payload: any) => ({ data: payload.select, kind: 'select' as const, isCount: false }), + // regular include: { include: { field: true } } + (payload: any) => ({ data: payload.include, kind: 'include' as const, isCount: false }), + // select _count: { select: { _count: { select: { field: true } } } } + (payload: any) => ({ + data: payload.select?._count?.select, + kind: 'select' as const, + isCount: true, + }), + // include _count: { include: { _count: { select: { field: true } } } } + (payload: any) => ({ + data: payload.include?._count?.select, + kind: 'include' as const, + isCount: true, + }), + ]; + + for (const selector of selectors) { + const { data, kind, isCount } = selector(args); + if (!data || typeof data !== 'object') { + continue; + } - if (this.isDelegateOrDescendantOfDelegate(fieldInfo?.type) && value) { - // delegate model, recursively inject hierarchy - if (args[kind][field]) { - if (args[kind][field] === true) { - // make sure the payload is an object - args[kind][field] = {}; - } - await this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]); + for (const [field, value] of Object.entries(data)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo) { + continue; + } + + if (this.isDelegateOrDescendantOfDelegate(fieldInfo?.type) && value) { + // delegate model, recursively inject hierarchy + if (data[field]) { + if (data[field] === true) { + // make sure the payload is an object + data[field] = {}; } + await this.injectSelectIncludeHierarchy(fieldInfo.type, data[field]); } + } - // refetch the field select/include value because it may have been - // updated during injection - const fieldValue = args[kind][field]; + // refetch the field select/include value because it may have been + // updated during injection + const fieldValue = data[field]; - if (fieldValue !== undefined) { - if (fieldValue.orderBy) { - // `orderBy` may contain fields from base types - enumerate(fieldValue.orderBy).forEach((item) => - this.injectWhereHierarchy(fieldInfo.type, item) - ); - } + if (fieldValue !== undefined) { + if (fieldValue.orderBy) { + // `orderBy` may contain fields from base types + enumerate(fieldValue.orderBy).forEach((item) => + this.injectWhereHierarchy(fieldInfo.type, item) + ); + } - if (this.injectBaseFieldSelect(model, field, fieldValue, args, kind)) { - delete args[kind][field]; - } else if (fieldInfo.isDataModel) { - let nextValue = fieldValue; - if (nextValue === true) { - // make sure the payload is an object - args[kind][field] = nextValue = {}; + let injected = false; + if (!isCount) { + // regular select/include injection + injected = await this.injectBaseFieldSelect(model, field, fieldValue, args, kind); + if (injected) { + // if injected, remove the field from the original payload + delete data[field]; + } + } else { + // _count select/include injection, inject into an empty payload and then merge to the proper location + const injectTarget = { [kind]: {} }; + injected = await this.injectBaseFieldSelect(model, field, fieldValue, injectTarget, kind, true); + if (injected) { + // if injected, remove the field from the original payload + delete data[field]; + if (Object.keys(data).length === 0) { + // if the original "_count" payload becomes empty, remove it + delete args[kind]['_count']; } - await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); + // finally merge the injection into the original payload + const merged = deepmerge(args[kind], injectTarget[kind]); + args[kind] = merged; + } + } + + if (!injected && fieldInfo.isDataModel) { + let nextValue = fieldValue; + if (nextValue === true) { + // make sure the payload is an object + data[field] = nextValue = {}; } + await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); } } } @@ -272,7 +327,8 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { field: string, value: any, selectInclude: any, - context: 'select' | 'include' + context: 'select' | 'include', + forCount = false // if the injection is for a "{ _count: { select: { field: true } } }" payload ) { const fieldInfo = resolveField(this.options.modelMeta, model, field); if (!fieldInfo?.inheritedFrom) { @@ -286,16 +342,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { const baseRelationName = this.makeAuxRelationName(base); // prepare base layer select/include - // let selectOrInclude = 'select'; let thisLayer: any; if (target.include) { - // selectOrInclude = 'include'; thisLayer = target.include; } else if (target.select) { - // selectOrInclude = 'select'; thisLayer = target.select; } else { - // selectInclude = 'include'; thisLayer = target.select = {}; } @@ -303,7 +355,22 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { if (!thisLayer[baseRelationName]) { thisLayer[baseRelationName] = { [context]: {} }; } - thisLayer[baseRelationName][context][field] = value; + if (forCount) { + // { _count: { select: { field: true } } } => { delegate_aux_[Base]: { select: { _count: { select: { field: true } } } } } + if ( + !thisLayer[baseRelationName][context]['_count'] || + typeof thisLayer[baseRelationName][context] !== 'object' + ) { + thisLayer[baseRelationName][context]['_count'] = {}; + } + thisLayer[baseRelationName][context]['_count'] = deepmerge( + thisLayer[baseRelationName][context]['_count'], + { select: { [field]: value } } + ); + } else { + // { select: { field: true } } => { delegate_aux_[Base]: { select: { field: true } } } + thisLayer[baseRelationName][context][field] = value; + } break; } else { if (!thisLayer[baseRelationName]) { diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index 91a385db0..7a555e0cd 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -378,6 +378,66 @@ describe('Polymorphism Test', () => { ).resolves.toHaveLength(1); }); + it('read with counting relation defined in base', async () => { + const { enhance } = await loadSchema( + ` + + model A { + id Int @id @default(autoincrement()) + type String + bs B[] + cs C[] + @@delegate(type) + } + + model A1 extends A { + a1 Int + type1 String + @@delegate(type1) + } + + model A2 extends A1 { + a2 Int + } + + model B { + id Int @id @default(autoincrement()) + a A @relation(fields: [aId], references: [id]) + aId Int + b Int + } + + model C { + id Int @id @default(autoincrement()) + a A @relation(fields: [aId], references: [id]) + aId Int + c Int + } + `, + { enhancements: ['delegate'] } + ); + const db = enhance(); + + const a2 = await db.a2.create({ + data: { a1: 1, a2: 2, bs: { create: [{ b: 1 }, { b: 2 }] }, cs: { create: [{ c: 1 }] } }, + include: { _count: { select: { bs: true } } }, + }); + expect(a2).toMatchObject({ a1: 1, a2: 2, _count: { bs: 2 } }); + + await expect( + db.a2.findFirst({ select: { a1: true, _count: { select: { bs: true } } } }) + ).resolves.toStrictEqual({ + a1: 1, + _count: { bs: 2 }, + }); + + await expect(db.a.findFirst({ select: { _count: { select: { bs: true, cs: true } } } })).resolves.toMatchObject( + { + _count: { bs: 2, cs: 1 }, + } + ); + }); + it('order by base fields', async () => { const { db, user } = await setup(); diff --git a/tests/regression/tests/issue-1467.test.ts b/tests/regression/tests/issue-1467.test.ts new file mode 100644 index 000000000..374313e45 --- /dev/null +++ b/tests/regression/tests/issue-1467.test.ts @@ -0,0 +1,51 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1467', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + type String + @@allow('all', true) + } + + model Container { + id Int @id @default(autoincrement()) + drink Drink @relation(fields: [drinkId], references: [id]) + drinkId Int + @@allow('all', true) + } + + model Drink { + id Int @id @default(autoincrement()) + name String @unique + containers Container[] + type String + + @@delegate(type) + @@allow('all', true) + } + + model Beer extends Drink { + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await db.beer.create({ + data: { id: 1, name: 'Beer1' }, + }); + + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + await db.container.create({ data: { drink: { connect: { id: 1 } } } }); + + const beers = await db.beer.findFirst({ + select: { id: true, name: true, _count: { select: { containers: true } } }, + orderBy: { name: 'asc' }, + }); + expect(beers).toMatchObject({ _count: { containers: 2 } }); + }); +}); From da08eed477fdc5b86a6e5c4e68789138482476f6 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 6 Jan 2025 18:45:55 +0800 Subject: [PATCH 11/16] chore: misc CLI fixes (#1939) --- packages/schema/src/cli/actions/generate.ts | 4 ++-- packages/schema/src/cli/actions/info.ts | 4 +++- packages/schema/src/cli/actions/init.ts | 3 +-- packages/schema/src/cli/cli-util.ts | 8 ++++---- packages/schema/src/cli/index.ts | 2 +- packages/schema/src/utils/version-utils.ts | 2 +- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/packages/schema/src/cli/actions/generate.ts b/packages/schema/src/cli/actions/generate.ts index d697504ee..229a9ddd8 100644 --- a/packages/schema/src/cli/actions/generate.ts +++ b/packages/schema/src/cli/actions/generate.ts @@ -37,8 +37,8 @@ export async function generate(projectPath: string, options: Options) { // check for multiple versions of Zenstack packages const packages = getZenStackPackages(projectPath); - if (packages) { - const versions = new Set(packages.map((p) => p.version)); + if (packages.length > 0) { + const versions = new Set(packages.map((p) => p.version).filter((v): v is string => !!v)); if (versions.size > 1) { console.warn( colors.yellow( diff --git a/packages/schema/src/cli/actions/info.ts b/packages/schema/src/cli/actions/info.ts index dddef9e27..c212babf4 100644 --- a/packages/schema/src/cli/actions/info.ts +++ b/packages/schema/src/cli/actions/info.ts @@ -16,7 +16,9 @@ export async function info(projectPath: string) { console.log('Installed ZenStack Packages:'); const versions = new Set(); for (const { pkg, version } of packages) { - versions.add(version); + if (version) { + versions.add(version); + } console.log(` ${colors.green(pkg.padEnd(20))}\t${version}`); } diff --git a/packages/schema/src/cli/actions/init.ts b/packages/schema/src/cli/actions/init.ts index 5790997e6..1016d61a9 100644 --- a/packages/schema/src/cli/actions/init.ts +++ b/packages/schema/src/cli/actions/init.ts @@ -63,8 +63,7 @@ export async function init(projectPath: string, options: Options) { if (sampleModelGenerated) { console.log(`Sample model generated at: ${colors.blue(zmodelFile)} -Please check the following guide on how to model your app: - https://zenstack.dev/#/modeling-your-app.`); +Learn how to use ZenStack: https://zenstack.dev/docs.`); } else if (prismaSchema) { console.log( `Your current Prisma schema "${prismaSchema}" has been copied to "${zmodelFile}". diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index b822f75ee..54ac123bd 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -227,13 +227,13 @@ export async function getPluginDocuments(services: ZModelServices, fileName: str return result; } -export function getZenStackPackages(projectPath: string) { +export function getZenStackPackages(projectPath: string): Array<{ pkg: string; version: string | undefined }> { let pkgJson: { dependencies: Record; devDependencies: Record }; const resolvedPath = path.resolve(projectPath); try { pkgJson = require(path.join(resolvedPath, 'package.json')); } catch { - return undefined; + return []; } const packages = [ @@ -245,7 +245,7 @@ export function getZenStackPackages(projectPath: string) { try { const resolved = require.resolve(`${pkg}/package.json`, { paths: [resolvedPath] }); // eslint-disable-next-line @typescript-eslint/no-var-requires - return { pkg, version: require(resolved).version }; + return { pkg, version: require(resolved).version as string }; } catch { return { pkg, version: undefined }; } @@ -286,7 +286,7 @@ export async function checkNewVersion() { return; } - if (latestVersion && semver.gt(latestVersion, currVersion)) { + if (latestVersion && currVersion && semver.gt(latestVersion, currVersion)) { console.log(`A newer version ${colors.cyan(latestVersion)} is available.`); } } diff --git a/packages/schema/src/cli/index.ts b/packages/schema/src/cli/index.ts index c58db8c43..62084ce9b 100644 --- a/packages/schema/src/cli/index.ts +++ b/packages/schema/src/cli/index.ts @@ -73,7 +73,7 @@ export const checkAction = async (options: Parameters[1]): export function createProgram() { const program = new Command('zenstack'); - program.version(getVersion(), '-v --version', 'display CLI version'); + program.version(getVersion()!, '-v --version', 'display CLI version'); const schemaExtensions = ZModelLanguageMetaData.fileExtensions.join(', '); diff --git a/packages/schema/src/utils/version-utils.ts b/packages/schema/src/utils/version-utils.ts index 0e2de705d..3a2daae57 100644 --- a/packages/schema/src/utils/version-utils.ts +++ b/packages/schema/src/utils/version-utils.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-var-requires */ -export function getVersion() { +export function getVersion(): string | undefined { try { return require('../package.json').version; } catch { From 3ee50d35a2c991dcf53d2bd92d34106c442ccc92 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 6 Jan 2025 19:46:48 +0800 Subject: [PATCH 12/16] fix(server): return an object without primary data for delete route (#1938) --- packages/server/src/api/rest/index.ts | 4 ++-- packages/server/tests/adapter/express.test.ts | 2 +- packages/server/tests/adapter/fastify.test.ts | 2 +- packages/server/tests/adapter/hono.test.ts | 2 +- packages/server/tests/adapter/next.test.ts | 2 +- packages/server/tests/adapter/sveltekit.test.ts | 2 +- packages/server/tests/api/rest.test.ts | 8 ++++---- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 1107fbc64..16de93637 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -1103,8 +1103,8 @@ class RequestHandler extends APIHandlerBase { where: this.makePrismaIdFilter(typeInfo.idFields, resourceId), }); return { - status: 204, - body: undefined, + status: 200, + body: { meta: {} }, }; } diff --git a/packages/server/tests/adapter/express.test.ts b/packages/server/tests/adapter/express.test.ts index 0627990e7..85ccc8a21 100644 --- a/packages/server/tests/adapter/express.test.ts +++ b/packages/server/tests/adapter/express.test.ts @@ -190,7 +190,7 @@ describe('Express adapter tests - rest handler', () => { expect(r.body.data.attributes.email).toBe('user1@def.com'); r = await request(app).delete(makeUrl('/api/user/user1')); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/fastify.test.ts b/packages/server/tests/adapter/fastify.test.ts index f03066e4f..ed4da3c72 100644 --- a/packages/server/tests/adapter/fastify.test.ts +++ b/packages/server/tests/adapter/fastify.test.ts @@ -233,7 +233,7 @@ describe('Fastify adapter tests - rest handler', () => { expect(r.json().data.attributes.email).toBe('user1@def.com'); r = await app.inject({ method: 'DELETE', url: '/api/user/user1' }); - expect(r.statusCode).toBe(204); + expect(r.statusCode).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/hono.test.ts b/packages/server/tests/adapter/hono.test.ts index 3fc1bb9da..fc55e1647 100644 --- a/packages/server/tests/adapter/hono.test.ts +++ b/packages/server/tests/adapter/hono.test.ts @@ -167,7 +167,7 @@ describe('Hono adapter tests - rest handler', () => { expect((await unmarshal(r)).data.attributes.email).toBe('user1@def.com'); r = await handler(makeRequest('DELETE', makeUrl(makeUrl('/api/user/user1')))); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/adapter/next.test.ts b/packages/server/tests/adapter/next.test.ts index b8652de7c..733b30ade 100644 --- a/packages/server/tests/adapter/next.test.ts +++ b/packages/server/tests/adapter/next.test.ts @@ -307,7 +307,7 @@ model M { expect(resp.body.data.attributes.value).toBe(2); }); - await makeTestClient('/m/1', options).del('/').expect(204); + await makeTestClient('/m/1', options).del('/').expect(200); expect(await prisma.m.count()).toBe(0); }); }); diff --git a/packages/server/tests/adapter/sveltekit.test.ts b/packages/server/tests/adapter/sveltekit.test.ts index d9663a2b6..650f89f85 100644 --- a/packages/server/tests/adapter/sveltekit.test.ts +++ b/packages/server/tests/adapter/sveltekit.test.ts @@ -164,7 +164,7 @@ describe('SvelteKit adapter tests - rest handler', () => { expect((await unmarshal(r)).data.attributes.email).toBe('user1@def.com'); r = await handler(makeRequest('DELETE', makeUrl(makeUrl('/api/user/user1')))); - expect(r.status).toBe(204); + expect(r.status).toBe(200); expect(await prisma.user.findMany()).toHaveLength(0); }); }); diff --git a/packages/server/tests/api/rest.test.ts b/packages/server/tests/api/rest.test.ts index 2a59e6067..b36755055 100644 --- a/packages/server/tests/api/rest.test.ts +++ b/packages/server/tests/api/rest.test.ts @@ -2340,8 +2340,8 @@ describe('REST server tests', () => { prisma, }); - expect(r.status).toBe(204); - expect(r.body).toBeUndefined(); + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ meta: {} }); }); it('deletes an item with compound id', async () => { @@ -2355,8 +2355,8 @@ describe('REST server tests', () => { path: `/postLike/1${idDivider}user1`, prisma, }); - expect(r.status).toBe(204); - expect(r.body).toBeUndefined(); + expect(r.status).toBe(200); + expect(r.body).toMatchObject({ meta: {} }); }); it('returns 404 if the user does not exist', async () => { From 00c19829b7e76e857843401532446ef894767e42 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 6 Jan 2025 21:57:59 +0800 Subject: [PATCH 13/16] fix(delegate): don't generate both "@@schema" attributes from base and sub models (#1940) --- packages/schema/src/utils/ast-utils.ts | 17 +++++- tests/regression/tests/issue-1647.test.ts | 69 +++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 tests/regression/tests/issue-1647.test.ts diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 0e462547f..f59ee7faa 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -16,7 +16,14 @@ import { ModelImport, TypeDef, } from '@zenstackhq/language/ast'; -import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; +import { + getAttribute, + getInheritanceChain, + getRecursiveBases, + hasAttribute, + isDelegateModel, + isFromStdlib, +} from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -96,6 +103,9 @@ function filterBaseAttribute(forModel: DataModel, base: DataModel, attr: DataMod // uninheritable attributes for delegate inheritance (they reference fields from the base) const uninheritableFromDelegateAttributes = ['@@unique', '@@index', '@@fulltext']; + // attributes that are inherited but can be overridden + const overrideAttributes = ['@@schema']; + if (uninheritableAttributes.includes(attr.decl.$refText)) { return false; } @@ -109,6 +119,11 @@ function filterBaseAttribute(forModel: DataModel, base: DataModel, attr: DataMod return false; } + if (hasAttribute(forModel, attr.decl.$refText) && overrideAttributes.includes(attr.decl.$refText)) { + // don't inherit an attribute if it's overridden in the sub model + return false; + } + return true; } diff --git a/tests/regression/tests/issue-1647.test.ts b/tests/regression/tests/issue-1647.test.ts new file mode 100644 index 000000000..e93f63cfb --- /dev/null +++ b/tests/regression/tests/issue-1647.test.ts @@ -0,0 +1,69 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import fs from 'fs'; + +describe('issue 1647', () => { + it('inherits @@schema by default', async () => { + const { projectDir } = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') + schemas = ['public', 'post'] + } + + generator client { + provider = 'prisma-client-js' + previewFeatures = ['multiSchema'] + } + + model Asset { + id Int @id + type String + @@delegate(type) + @@schema('public') + } + + model Post extends Asset { + title String + } + `, + { addPrelude: false, pushDb: false, getPrismaOnly: true } + ); + + const prismaSchema = fs.readFileSync(`${projectDir}/prisma/schema.prisma`, 'utf-8'); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("public")'))).toHaveLength(2); + }); + it('respects sub model @@schema overrides', async () => { + const { projectDir } = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') + schemas = ['public', 'post'] + } + + generator client { + provider = 'prisma-client-js' + previewFeatures = ['multiSchema'] + } + + model Asset { + id Int @id + type String + @@delegate(type) + @@schema('public') + } + + model Post extends Asset { + title String + @@schema('post') + } + `, + { addPrelude: false, pushDb: false, getPrismaOnly: true } + ); + + const prismaSchema = fs.readFileSync(`${projectDir}/prisma/schema.prisma`, 'utf-8'); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("public")'))).toHaveLength(1); + expect(prismaSchema.split('\n').filter((l) => l.includes('@@schema("post")'))).toHaveLength(1); + }); +}); From 7ed98415478ecab2a13ae175c5065b4c64d2ff3a Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 7 Jan 2025 14:45:44 +0800 Subject: [PATCH 14/16] feat(encryption): support providing multiple decryption keys for key rotation (#1942) --- .../src/enhancements/node/encryption.ts | 161 ++++++++++++++---- packages/runtime/src/types.ts | 33 +++- .../with-encrypted/with-encrypted.test.ts | 75 +++++++- 3 files changed, 231 insertions(+), 38 deletions(-) diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts index 3d0f738d4..65666d8cd 100644 --- a/packages/runtime/src/enhancements/node/encryption.ts +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ +import { z } from 'zod'; import { FieldInfo, NestedWriteVisitor, @@ -37,6 +38,24 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { private encoder = new TextEncoder(); private decoder = new TextDecoder(); private logger: Logger; + private encryptionKey: CryptoKey | undefined; + private encryptionKeyDigest: string | undefined; + private decryptionKeys: Array<{ key: CryptoKey; digest: string }> = []; + private encryptionMetaSchema = z.object({ + // version + v: z.number(), + // algorithm + a: z.string(), + // key digest + k: z.string(), + }); + + // constants + private readonly ENCRYPTION_KEY_BYTES = 32; + private readonly IV_BYTES = 12; + private readonly ALGORITHM = 'AES-GCM'; + private readonly ENCRYPTER_VERSION = 1; + private readonly KEY_DIGEST_BYTES = 8; constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); @@ -44,49 +63,102 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { this.queryUtils = new QueryUtils(prisma, options); this.logger = new Logger(prisma); - if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided'); + if (!options.encryption) { + throw this.queryUtils.unknownError('Encryption options must be provided'); + } if (this.isCustomEncryption(options.encryption!)) { - if (!options.encryption.encrypt || !options.encryption.decrypt) + if (!options.encryption.encrypt || !options.encryption.decrypt) { throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions'); + } } else { - if (!options.encryption.encryptionKey) + if (!options.encryption.encryptionKey) { throw this.queryUtils.unknownError('Encryption key must be provided'); - if (options.encryption.encryptionKey.length !== 32) - throw this.queryUtils.unknownError('Encryption key must be 32 bytes'); + } + if (options.encryption.encryptionKey.length !== this.ENCRYPTION_KEY_BYTES) { + throw this.queryUtils.unknownError(`Encryption key must be ${this.ENCRYPTION_KEY_BYTES} bytes`); + } } } - private async getKey(secret: Uint8Array): Promise { - return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']); - } - private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption { return 'encrypt' in encryption && 'decrypt' in encryption; } + private async loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise { + return crypto.subtle.importKey('raw', key, this.ALGORITHM, false, keyUsages); + } + + private async computeKeyDigest(key: Uint8Array) { + const rawDigest = await crypto.subtle.digest('SHA-256', key); + return new Uint8Array(rawDigest.slice(0, this.KEY_DIGEST_BYTES)).reduce( + (acc, byte) => acc + byte.toString(16).padStart(2, '0'), + '' + ); + } + + private async getEncryptionKey(): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + throw new Error('Unexpected custom encryption settings'); + } + if (!this.encryptionKey) { + this.encryptionKey = await this.loadKey(this.options.encryption!.encryptionKey, ['encrypt', 'decrypt']); + } + return this.encryptionKey; + } + + private async getEncryptionKeyDigest() { + if (this.isCustomEncryption(this.options.encryption!)) { + throw new Error('Unexpected custom encryption settings'); + } + if (!this.encryptionKeyDigest) { + this.encryptionKeyDigest = await this.computeKeyDigest(this.options.encryption!.encryptionKey); + } + return this.encryptionKeyDigest; + } + + private async findDecryptionKeys(keyDigest: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + throw new Error('Unexpected custom encryption settings'); + } + + if (this.decryptionKeys.length === 0) { + const keys = [this.options.encryption!.encryptionKey, ...(this.options.encryption!.decryptionKeys || [])]; + this.decryptionKeys = await Promise.all( + keys.map(async (key) => ({ + key: await this.loadKey(key, ['decrypt']), + digest: await this.computeKeyDigest(key), + })) + ); + } + + return this.decryptionKeys.filter((entry) => entry.digest === keyDigest).map((entry) => entry.key); + } + private async encrypt(field: FieldInfo, data: string): Promise { if (this.isCustomEncryption(this.options.encryption!)) { return this.options.encryption.encrypt(this.model, field, data); } - const key = await this.getKey(this.options.encryption!.encryptionKey); - const iv = crypto.getRandomValues(new Uint8Array(12)); - + const key = await this.getEncryptionKey(); + const iv = crypto.getRandomValues(new Uint8Array(this.IV_BYTES)); const encrypted = await crypto.subtle.encrypt( { - name: 'AES-GCM', + name: this.ALGORITHM, iv, }, key, this.encoder.encode(data) ); - // Combine IV and encrypted data into a single array of bytes - const bytes = [...iv, ...new Uint8Array(encrypted)]; + // combine IV and encrypted data into a single array of bytes + const cipherBytes = [...iv, ...new Uint8Array(encrypted)]; + + // encryption metadata + const meta = { v: this.ENCRYPTER_VERSION, a: this.ALGORITHM, k: await this.getEncryptionKeyDigest() }; - // Convert bytes to base64 string - return btoa(String.fromCharCode(...bytes)); + // convert concatenated result to base64 string + return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`; } private async decrypt(field: FieldInfo, data: string): Promise { @@ -94,22 +166,47 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { return this.options.encryption.decrypt(this.model, field, data); } - const key = await this.getKey(this.options.encryption!.encryptionKey); + const [metaText, cipherText] = data.split('.'); + if (!metaText || !cipherText) { + throw new Error('Malformed encrypted data'); + } - // Convert base64 back to bytes - const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0)); + let metaObj: unknown; + try { + metaObj = JSON.parse(atob(metaText)); + } catch (error) { + throw new Error('Malformed metadata'); + } - // First 12 bytes are IV, rest is encrypted data - const decrypted = await crypto.subtle.decrypt( - { - name: 'AES-GCM', - iv: bytes.slice(0, 12), - }, - key, - bytes.slice(12) - ); + // parse meta + const { a: algorithm, k: keyDigest } = this.encryptionMetaSchema.parse(metaObj); + + // find a matching decryption key + const keys = await this.findDecryptionKeys(keyDigest); + if (keys.length === 0) { + throw new Error('No matching decryption key found'); + } + + // convert base64 back to bytes + const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0)); + + // extract IV from the head + const iv = bytes.slice(0, this.IV_BYTES); + const cipher = bytes.slice(this.IV_BYTES); + let lastError: unknown; + + for (const key of keys) { + let decrypted: ArrayBuffer; + try { + decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher); + } catch (err) { + lastError = err; + continue; + } + return this.decoder.decode(decrypted); + } - return this.decoder.decode(decrypted); + throw lastError; } // base override @@ -138,7 +235,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData); for (const field of getModelFields(entityData)) { - // Don't decrypt null, undefined or empty string values + // don't decrypt null, undefined or empty string values if (!entityData[field]) continue; const fieldInfo = await resolveField(this.options.modelMeta, realModel, field); @@ -169,7 +266,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { const visitor = new NestedWriteVisitor(this.options.modelMeta, { field: async (field, _action, data, context) => { - // Don't encrypt null, undefined or empty string values + // don't encrypt null, undefined or empty string values if (!data) return; const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted'); diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 012c94699..fe31a5058 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -173,9 +173,38 @@ export type ZodSchemas = { input?: Record>; }; +/** + * Simple encryption settings for processing fields marked with `@encrypted`. + */ +export type SimpleEncryption = { + /** + * The encryption key. + */ + encryptionKey: Uint8Array; + + /** + * Optional list of all decryption keys that were previously used to encrypt the data + * , for supporting key rotation. The `encryptionKey` field value is automatically + * included for decryption. + * + * When the encrypted data is persisted, a metadata object containing the digest of the + * encryption key is stored alongside the data. This digest is used to quickly determine + * the correct decryption key to use when reading the data. + */ + decryptionKeys?: Uint8Array[]; +}; + +/** + * Custom encryption settings for processing fields marked with `@encrypted`. + */ export type CustomEncryption = { + /** + * Encryption function. + */ encrypt: (model: string, field: FieldInfo, plain: string) => Promise; + + /** + * Decryption function + */ decrypt: (model: string, field: FieldInfo, cipher: string) => Promise; }; - -export type SimpleEncryption = { encryptionKey: Uint8Array }; diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 3f02c3f70..71ccd0323 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -143,13 +143,12 @@ describe('Encrypted test', () => { }); it('Custom encryption test', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema(` model User { id String @id @default(cuid()) encrypted_value String @encrypted() }`); - const sudoDb = enhance(undefined, { kinds: [] }); const db = enhance(undefined, { kinds: ['encryption'], encryption: { @@ -181,7 +180,7 @@ describe('Encrypted test', () => { }, }); - const sudoRead = await sudoDb.user.findUnique({ + const rawRead = await prisma.user.findUnique({ where: { id: '1', }, @@ -189,7 +188,75 @@ describe('Encrypted test', () => { expect(create.encrypted_value).toBe('abc123'); expect(read.encrypted_value).toBe('abc123'); - expect(sudoRead.encrypted_value).toBe('abc123_enc'); + expect(rawRead.encrypted_value).toBe('abc123_enc'); + }); + + it('Works with multiple decryption keys', async () => { + const { enhanceRaw: enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + secret String @encrypted() + }` + ); + + const key1 = crypto.getRandomValues(new Uint8Array(32)); + const key2 = crypto.getRandomValues(new Uint8Array(32)); + + const db1 = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1 }, + }); + const user1 = await db1.user.create({ data: { secret: 'user1' } }); + + const db2 = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key2 }, + }); + const user2 = await db2.user.create({ data: { secret: 'user2' } }); + + const dbAll = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)), decryptionKeys: [key1, key2] }, + }); + const allUsers = await dbAll.user.findMany(); + expect(allUsers).toEqual(expect.arrayContaining([user1, user2])); + + const dbWithEncryptionKeyExplicitlyProvided = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key1, key2] }, + }); + await expect(dbWithEncryptionKeyExplicitlyProvided.user.findMany()).resolves.toEqual( + expect.arrayContaining([user1, user2]) + ); + + const dbWithDuplicatedKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key1, key1, key2, key2] }, + }); + await expect(dbWithDuplicatedKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2])); + + const dbWithInvalidKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key1, decryptionKeys: [key2, crypto.getRandomValues(new Uint8Array(32))] }, + }); + await expect(dbWithInvalidKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2])); + + const dbWithMissingKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: key2 }, + }); + const found = await dbWithMissingKeys.user.findMany(); + expect(found).not.toContainEqual(user1); + expect(found).toContainEqual(user2); + + const dbWithAllWrongKeys = enhance(prisma, undefined, { + kinds: ['encryption'], + encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) }, + }); + const found1 = await dbWithAllWrongKeys.user.findMany(); + expect(found1).not.toContainEqual(user1); + expect(found1).not.toContainEqual(user2); }); it('Only supports string fields', async () => { From 96d0ce502154a0216f28d444ffc45b00c7f9f741 Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 7 Jan 2025 16:47:42 +0800 Subject: [PATCH 15/16] fix(encryption): fixes for `createMany` and `createManyAndReturn` operations (#1944) --- packages/runtime/src/constants.ts | 12 ++ .../runtime/src/cross/nested-write-visitor.ts | 48 +++---- .../src/enhancements/node/default-auth.ts | 11 +- .../src/enhancements/node/encryption.ts | 4 +- .../runtime/src/enhancements/node/password.ts | 5 +- .../with-encrypted/with-encrypted.test.ts | 118 ++++++++++++++++++ .../with-password/with-password.test.ts | 23 +++- 7 files changed, 183 insertions(+), 38 deletions(-) diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 36acf8c83..495e1853d 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -67,3 +67,15 @@ export const PRISMA_MINIMUM_VERSION = '5.0.0'; * Prefix for auxiliary relation field generated for delegated models */ export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; + +/** + * Prisma actions that can have a write payload + */ +export const ACTIONS_WITH_WRITE_PAYLOAD = [ + 'create', + 'createMany', + 'createManyAndReturn', + 'update', + 'updateMany', + 'upsert', +]; diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index c69f9d203..ba4b232a6 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -4,7 +4,7 @@ import type { FieldInfo, ModelMeta } from './model-meta'; import { resolveField } from './model-meta'; import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types'; -import { getModelFields } from './utils'; +import { enumerate, getModelFields } from './utils'; type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean }; @@ -310,31 +310,33 @@ export class NestedWriteVisitor { payload: any, nestingPath: NestingPathItem[] ) { - for (const field of getModelFields(payload)) { - const fieldInfo = resolveField(this.modelMeta, model, field); - if (!fieldInfo) { - continue; - } + for (const item of enumerate(payload)) { + for (const field of getModelFields(item)) { + const fieldInfo = resolveField(this.modelMeta, model, field); + if (!fieldInfo) { + continue; + } - if (fieldInfo.isDataModel) { - if (payload[field]) { - // recurse into nested payloads - for (const [subAction, subData] of Object.entries(payload[field])) { - if (this.isPrismaWriteAction(subAction) && subData) { - await this.doVisit(fieldInfo.type, subAction, subData, payload[field], fieldInfo, [ - ...nestingPath, - ]); + if (fieldInfo.isDataModel) { + if (item[field]) { + // recurse into nested payloads + for (const [subAction, subData] of Object.entries(item[field])) { + if (this.isPrismaWriteAction(subAction) && subData) { + await this.doVisit(fieldInfo.type, subAction, subData, item[field], fieldInfo, [ + ...nestingPath, + ]); + } } } - } - } else { - // visit plain field - if (this.callback.field) { - await this.callback.field(fieldInfo, action, payload[field], { - parent: payload, - nestingPath, - field: fieldInfo, - }); + } else { + // visit plain field + if (this.callback.field) { + await this.callback.field(fieldInfo, action, item[field], { + parent: item, + nestingPath, + field: fieldInfo, + }); + } } } } diff --git a/packages/runtime/src/enhancements/node/default-auth.ts b/packages/runtime/src/enhancements/node/default-auth.ts index 03ce3750c..e6162a2d2 100644 --- a/packages/runtime/src/enhancements/node/default-auth.ts +++ b/packages/runtime/src/enhancements/node/default-auth.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, @@ -50,15 +51,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = [ - 'create', - 'createMany', - 'createManyAndReturn', - 'update', - 'updateMany', - 'upsert', - ]; - if (actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); return newArgs; } diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts index 65666d8cd..42001fc16 100644 --- a/packages/runtime/src/enhancements/node/encryption.ts +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -2,6 +2,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { z } from 'zod'; +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, @@ -211,8 +212,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; diff --git a/packages/runtime/src/enhancements/node/password.ts b/packages/runtime/src/enhancements/node/password.ts index 8c1aeb959..a2fdae42c 100644 --- a/packages/runtime/src/enhancements/node/password.ts +++ b/packages/runtime/src/enhancements/node/password.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; +import { ACTIONS_WITH_WRITE_PAYLOAD, DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; import { NestedWriteVisitor, type PrismaWriteActionType } from '../../cross'; import { DbClientContract } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; @@ -39,8 +39,7 @@ class PasswordHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 71ccd0323..71d32769f 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -58,6 +58,124 @@ describe('Encrypted test', () => { expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).not.toBe('abc123'); expect(rawRead.encrypted_value).not.toBe('abc123'); + + // update + const updated = await db.user.update({ + where: { id: '1' }, + data: { encrypted_value: 'abc234' }, + }); + expect(updated.encrypted_value).toBe('abc234'); + await expect(db.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ + encrypted_value: 'abc234', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc234', + }); + + // upsert with create + const upsertCreate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertCreate.encrypted_value).toBe('abc345'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc345', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc345', + }); + + // upsert with update + const upsertUpdate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertUpdate.encrypted_value).toBe('abc456'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc456', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc456', + }); + + // createMany + await db.user.createMany({ + data: [ + { id: '3', encrypted_value: 'abc567' }, + { id: '4', encrypted_value: 'abc678' }, + ], + }); + await expect(db.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ + encrypted_value: 'abc567', + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc567', + }); + + // createManyAndReturn + await expect( + db.user.createManyAndReturn({ + data: [ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ], + }) + ).resolves.toEqual( + expect.arrayContaining([ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ]) + ); + await expect(db.user.findUnique({ where: { id: '5' } })).resolves.toMatchObject({ + encrypted_value: 'abc789', + }); + await expect(prisma.user.findUnique({ where: { id: '5' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc789', + }); + }); + + it('Works with nullish values', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String? @encrypted() + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: '1', encrypted_value: '' } })).resolves.toMatchObject({ + encrypted_value: '', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ encrypted_value: '' }); + + await expect(db.user.create({ data: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ encrypted_value: null }); + + await expect(db.user.create({ data: { id: '3', encrypted_value: null } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ encrypted_value: null }); }); it('Decrypts nested fields', async () => { diff --git a/tests/integration/tests/enhancements/with-password/with-password.test.ts b/tests/integration/tests/enhancements/with-password/with-password.test.ts index b2fd89a65..a54d0c42d 100644 --- a/tests/integration/tests/enhancements/with-password/with-password.test.ts +++ b/tests/integration/tests/enhancements/with-password/with-password.test.ts @@ -14,7 +14,7 @@ describe('Password test', () => { }); it('password tests', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema(` model User { id String @id @default(cuid()) password String @password(saltLength: 16) @@ -38,6 +38,27 @@ describe('Password test', () => { }, }); expect(compareSync('abc456', r1.password)).toBeTruthy(); + + await db.user.createMany({ + data: [ + { id: '2', password: 'user2' }, + { id: '3', password: 'user3' }, + ], + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ password: 'user2' }); + const r2 = await db.user.findUnique({ where: { id: '2' } }); + expect(compareSync('user2', r2.password)).toBeTruthy(); + + const [u4] = await db.user.createManyAndReturn({ + data: [ + { id: '4', password: 'user4' }, + { id: '5', password: 'user5' }, + ], + }); + expect(compareSync('user4', u4.password)).toBeTruthy(); + await expect(prisma.user.findUnique({ where: { id: '4' } })).resolves.not.toMatchObject({ password: 'user4' }); + const r4 = await db.user.findUnique({ where: { id: '4' } }); + expect(compareSync('user4', r4.password)).toBeTruthy(); }); it('length tests', async () => { From 4605278f47616ae1f1a0691af690734d576819bd Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 7 Jan 2025 19:13:24 +0800 Subject: [PATCH 16/16] refactor(encryption): extract standalone encrypter/decrypter (#1945) --- packages/runtime/package.json | 4 + packages/runtime/src/encryption/index.ts | 67 +++++++++ packages/runtime/src/encryption/utils.ts | 96 ++++++++++++ .../src/enhancements/node/encryption.ts | 142 ++---------------- 4 files changed, 178 insertions(+), 131 deletions(-) create mode 100644 packages/runtime/src/encryption/index.ts create mode 100644 packages/runtime/src/encryption/utils.ts diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 3dcd6c4d6..1f6f106aa 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -80,6 +80,10 @@ "types": "./zod-utils.d.ts", "default": "./zod-utils.js" }, + "./encryption": { + "types": "./encryption/index.d.ts", + "default": "./encryption/index.js" + }, "./package.json": { "default": "./package.json" } diff --git a/packages/runtime/src/encryption/index.ts b/packages/runtime/src/encryption/index.ts new file mode 100644 index 000000000..d4cb31db6 --- /dev/null +++ b/packages/runtime/src/encryption/index.ts @@ -0,0 +1,67 @@ +import { _decrypt, _encrypt, ENCRYPTION_KEY_BYTES, getKeyDigest, loadKey } from './utils'; + +/** + * Default encrypter + */ +export class Encrypter { + private key: CryptoKey | undefined; + private keyDigest: string | undefined; + + constructor(private readonly encryptionKey: Uint8Array) { + if (encryptionKey.length !== ENCRYPTION_KEY_BYTES) { + throw new Error(`Encryption key must be ${ENCRYPTION_KEY_BYTES} bytes`); + } + } + + /** + * Encrypts the given data + */ + async encrypt(data: string): Promise { + if (!this.key) { + this.key = await loadKey(this.encryptionKey, ['encrypt']); + } + + if (!this.keyDigest) { + this.keyDigest = await getKeyDigest(this.encryptionKey); + } + + return _encrypt(data, this.key, this.keyDigest); + } +} + +/** + * Default decrypter + */ +export class Decrypter { + private keys: Array<{ key: CryptoKey; digest: string }> = []; + + constructor(private readonly decryptionKeys: Uint8Array[]) { + if (decryptionKeys.length === 0) { + throw new Error('At least one decryption key must be provided'); + } + + for (const key of decryptionKeys) { + if (key.length !== ENCRYPTION_KEY_BYTES) { + throw new Error(`Decryption key must be ${ENCRYPTION_KEY_BYTES} bytes`); + } + } + } + + /** + * Decrypts the given data + */ + async decrypt(data: string): Promise { + if (this.keys.length === 0) { + this.keys = await Promise.all( + this.decryptionKeys.map(async (key) => ({ + key: await loadKey(key, ['decrypt']), + digest: await getKeyDigest(key), + })) + ); + } + + return _decrypt(data, async (digest) => + this.keys.filter((entry) => entry.digest === digest).map((entry) => entry.key) + ); + } +} diff --git a/packages/runtime/src/encryption/utils.ts b/packages/runtime/src/encryption/utils.ts new file mode 100644 index 000000000..51ab41dc7 --- /dev/null +++ b/packages/runtime/src/encryption/utils.ts @@ -0,0 +1,96 @@ +import { z } from 'zod'; + +export const ENCRYPTER_VERSION = 1; +export const ENCRYPTION_KEY_BYTES = 32; +export const IV_BYTES = 12; +export const ALGORITHM = 'AES-GCM'; +export const KEY_DIGEST_BYTES = 8; + +const encoder = new TextEncoder(); +const decoder = new TextDecoder(); + +const encryptionMetaSchema = z.object({ + // version + v: z.number(), + // algorithm + a: z.string(), + // key digest + k: z.string(), +}); + +export async function loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise { + return crypto.subtle.importKey('raw', key, ALGORITHM, false, keyUsages); +} + +export async function getKeyDigest(key: Uint8Array) { + const rawDigest = await crypto.subtle.digest('SHA-256', key); + return new Uint8Array(rawDigest.slice(0, KEY_DIGEST_BYTES)).reduce( + (acc, byte) => acc + byte.toString(16).padStart(2, '0'), + '' + ); +} + +export async function _encrypt(data: string, key: CryptoKey, keyDigest: string): Promise { + const iv = crypto.getRandomValues(new Uint8Array(IV_BYTES)); + const encrypted = await crypto.subtle.encrypt( + { + name: ALGORITHM, + iv, + }, + key, + encoder.encode(data) + ); + + // combine IV and encrypted data into a single array of bytes + const cipherBytes = [...iv, ...new Uint8Array(encrypted)]; + + // encryption metadata + const meta = { v: ENCRYPTER_VERSION, a: ALGORITHM, k: keyDigest }; + + // convert concatenated result to base64 string + return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`; +} + +export async function _decrypt(data: string, findKey: (digest: string) => Promise): Promise { + const [metaText, cipherText] = data.split('.'); + if (!metaText || !cipherText) { + throw new Error('Malformed encrypted data'); + } + + let metaObj: unknown; + try { + metaObj = JSON.parse(atob(metaText)); + } catch (error) { + throw new Error('Malformed metadata'); + } + + // parse meta + const { a: algorithm, k: keyDigest } = encryptionMetaSchema.parse(metaObj); + + // find a matching decryption key + const keys = await findKey(keyDigest); + if (keys.length === 0) { + throw new Error('No matching decryption key found'); + } + + // convert base64 back to bytes + const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0)); + + // extract IV from the head + const iv = bytes.slice(0, IV_BYTES); + const cipher = bytes.slice(IV_BYTES); + let lastError: unknown; + + for (const key of keys) { + let decrypted: ArrayBuffer; + try { + decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher); + } catch (err) { + lastError = err; + continue; + } + return decoder.decode(decrypted); + } + + throw lastError; +} diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts index 42001fc16..4859a1225 100644 --- a/packages/runtime/src/enhancements/node/encryption.ts +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -1,7 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { z } from 'zod'; import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, @@ -11,6 +10,7 @@ import { resolveField, type PrismaWriteActionType, } from '../../cross'; +import { Decrypter, Encrypter } from '../../encryption'; import { CustomEncryption, DbClientContract, SimpleEncryption } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; import { Logger } from './logger'; @@ -36,27 +36,12 @@ export function withEncrypted( class EncryptedHandler extends DefaultPrismaProxyHandler { private queryUtils: QueryUtils; - private encoder = new TextEncoder(); - private decoder = new TextDecoder(); private logger: Logger; private encryptionKey: CryptoKey | undefined; private encryptionKeyDigest: string | undefined; private decryptionKeys: Array<{ key: CryptoKey; digest: string }> = []; - private encryptionMetaSchema = z.object({ - // version - v: z.number(), - // algorithm - a: z.string(), - // key digest - k: z.string(), - }); - - // constants - private readonly ENCRYPTION_KEY_BYTES = 32; - private readonly IV_BYTES = 12; - private readonly ALGORITHM = 'AES-GCM'; - private readonly ENCRYPTER_VERSION = 1; - private readonly KEY_DIGEST_BYTES = 8; + private encrypter: Encrypter | undefined; + private decrypter: Decrypter | undefined; constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { super(prisma, model, options); @@ -76,9 +61,12 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { if (!options.encryption.encryptionKey) { throw this.queryUtils.unknownError('Encryption key must be provided'); } - if (options.encryption.encryptionKey.length !== this.ENCRYPTION_KEY_BYTES) { - throw this.queryUtils.unknownError(`Encryption key must be ${this.ENCRYPTION_KEY_BYTES} bytes`); - } + + this.encrypter = new Encrypter(options.encryption.encryptionKey); + this.decrypter = new Decrypter([ + options.encryption.encryptionKey, + ...(options.encryption.decryptionKeys || []), + ]); } } @@ -86,80 +74,12 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { return 'encrypt' in encryption && 'decrypt' in encryption; } - private async loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise { - return crypto.subtle.importKey('raw', key, this.ALGORITHM, false, keyUsages); - } - - private async computeKeyDigest(key: Uint8Array) { - const rawDigest = await crypto.subtle.digest('SHA-256', key); - return new Uint8Array(rawDigest.slice(0, this.KEY_DIGEST_BYTES)).reduce( - (acc, byte) => acc + byte.toString(16).padStart(2, '0'), - '' - ); - } - - private async getEncryptionKey(): Promise { - if (this.isCustomEncryption(this.options.encryption!)) { - throw new Error('Unexpected custom encryption settings'); - } - if (!this.encryptionKey) { - this.encryptionKey = await this.loadKey(this.options.encryption!.encryptionKey, ['encrypt', 'decrypt']); - } - return this.encryptionKey; - } - - private async getEncryptionKeyDigest() { - if (this.isCustomEncryption(this.options.encryption!)) { - throw new Error('Unexpected custom encryption settings'); - } - if (!this.encryptionKeyDigest) { - this.encryptionKeyDigest = await this.computeKeyDigest(this.options.encryption!.encryptionKey); - } - return this.encryptionKeyDigest; - } - - private async findDecryptionKeys(keyDigest: string): Promise { - if (this.isCustomEncryption(this.options.encryption!)) { - throw new Error('Unexpected custom encryption settings'); - } - - if (this.decryptionKeys.length === 0) { - const keys = [this.options.encryption!.encryptionKey, ...(this.options.encryption!.decryptionKeys || [])]; - this.decryptionKeys = await Promise.all( - keys.map(async (key) => ({ - key: await this.loadKey(key, ['decrypt']), - digest: await this.computeKeyDigest(key), - })) - ); - } - - return this.decryptionKeys.filter((entry) => entry.digest === keyDigest).map((entry) => entry.key); - } - private async encrypt(field: FieldInfo, data: string): Promise { if (this.isCustomEncryption(this.options.encryption!)) { return this.options.encryption.encrypt(this.model, field, data); } - const key = await this.getEncryptionKey(); - const iv = crypto.getRandomValues(new Uint8Array(this.IV_BYTES)); - const encrypted = await crypto.subtle.encrypt( - { - name: this.ALGORITHM, - iv, - }, - key, - this.encoder.encode(data) - ); - - // combine IV and encrypted data into a single array of bytes - const cipherBytes = [...iv, ...new Uint8Array(encrypted)]; - - // encryption metadata - const meta = { v: this.ENCRYPTER_VERSION, a: this.ALGORITHM, k: await this.getEncryptionKeyDigest() }; - - // convert concatenated result to base64 string - return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`; + return this.encrypter!.encrypt(data); } private async decrypt(field: FieldInfo, data: string): Promise { @@ -167,47 +87,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { return this.options.encryption.decrypt(this.model, field, data); } - const [metaText, cipherText] = data.split('.'); - if (!metaText || !cipherText) { - throw new Error('Malformed encrypted data'); - } - - let metaObj: unknown; - try { - metaObj = JSON.parse(atob(metaText)); - } catch (error) { - throw new Error('Malformed metadata'); - } - - // parse meta - const { a: algorithm, k: keyDigest } = this.encryptionMetaSchema.parse(metaObj); - - // find a matching decryption key - const keys = await this.findDecryptionKeys(keyDigest); - if (keys.length === 0) { - throw new Error('No matching decryption key found'); - } - - // convert base64 back to bytes - const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0)); - - // extract IV from the head - const iv = bytes.slice(0, this.IV_BYTES); - const cipher = bytes.slice(this.IV_BYTES); - let lastError: unknown; - - for (const key of keys) { - let decrypted: ArrayBuffer; - try { - decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher); - } catch (err) { - lastError = err; - continue; - } - return this.decoder.decode(decrypted); - } - - throw lastError; + return this.decrypter!.decrypt(data); } // base override