Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: more CLI options #695

Merged
merged 2 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"semver": "^7.3.8",
"sleep-promise": "^9.1.0",
"strip-color": "^0.1.0",
"tiny-invariant": "^1.3.1",
"ts-morph": "^16.0.0",
"ts-pattern": "^4.3.0",
"upper-case-first": "^2.0.2",
Expand Down
17 changes: 10 additions & 7 deletions packages/schema/src/cli/actions/generate.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import { PluginError } from '@zenstackhq/sdk';
import colors from 'colors';
import path from 'path';
import { Context } from '../../types';
import { PackageManagers } from '../../utils/pkg-utils';
import { CliError } from '../cli-error';
import {
checkNewVersion,
Expand All @@ -11,13 +9,15 @@ import {
loadDocument,
requiredPrismaVersion,
} from '../cli-util';
import { PluginRunner } from '../plugin-runner';
import { PluginRunner, PluginRunnerOptions } from '../plugin-runner';

type Options = {
schema: string;
packageManager: PackageManagers | undefined;
output?: string;
dependencyCheck: boolean;
versionCheck: boolean;
compile: boolean;
defaultPlugins: boolean;
};

/**
Expand Down Expand Up @@ -53,14 +53,17 @@ export async function generate(projectPath: string, options: Options) {

async function runPlugins(options: Options) {
const model = await loadDocument(options.schema);
const context: Context = {

const runnerOpts: PluginRunnerOptions = {
schema: model,
schemaPath: path.resolve(options.schema),
outDir: path.dirname(options.schema),
defaultPlugins: options.defaultPlugins,
output: options.output,
compile: options.compile,
};

try {
await new PluginRunner().run(context);
await new PluginRunner().run(runnerOpts);
} catch (err) {
if (err instanceof PluginError) {
console.error(colors.red(`${err.plugin}: ${err.message}`));
Expand Down
6 changes: 4 additions & 2 deletions packages/schema/src/cli/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export function createProgram() {
.addOption(configOption)
.addOption(pmOption)
.addOption(new Option('--prisma <file>', 'location of Prisma schema file to bootstrap from'))
.addOption(new Option('--tag [tag]', 'the NPM package tag to use when installing dependencies'))
.addOption(new Option('--tag <tag>', 'the NPM package tag to use when installing dependencies'))
.addOption(noVersionCheckOption)
.argument('[path]', 'project path', '.')
.action(initAction);
Expand All @@ -90,8 +90,10 @@ export function createProgram() {
.command('generate')
.description('Run code generation.')
.addOption(schemaOption)
.addOption(new Option('-o, --output <path>', 'default output directory for built-in plugins'))
.addOption(configOption)
.addOption(pmOption)
.addOption(new Option('--no-default-plugins', 'do not run default plugins'))
.addOption(new Option('--no-compile', 'do not compile the output of built-in plugins'))
.addOption(noVersionCheckOption)
.addOption(noDependencyCheck)
.action(generateAction);
Expand Down
130 changes: 85 additions & 45 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-var-requires */
import type { DMMF } from '@prisma/generator-helper';
import { isPlugin, Plugin } from '@zenstackhq/language/ast';
import { isPlugin, Model, Plugin } from '@zenstackhq/language/ast';
import {
getDataModels,
getDMMF,
Expand All @@ -19,9 +19,7 @@ import ora from 'ora';
import path from 'path';
import { ensureDefaultOutputFolder } from '../plugins/plugin-utils';
import telemetry from '../telemetry';
import type { Context } from '../types';
import { getVersion } from '../utils/version-utils';
import { config } from './config';

type PluginInfo = {
name: string;
Expand All @@ -32,23 +30,31 @@ type PluginInfo = {
module: any;
};

export type PluginRunnerOptions = {
schema: Model;
schemaPath: string;
output?: string;
defaultPlugins: boolean;
compile: boolean;
};

/**
* ZenStack plugin runner
*/
export class PluginRunner {
/**
* Runs a series of nested generators
*/
async run(context: Context): Promise<void> {
async run(options: PluginRunnerOptions): Promise<void> {
const version = getVersion();
console.log(colors.bold(`⌛️ ZenStack CLI v${version}, running plugins`));

ensureDefaultOutputFolder();
ensureDefaultOutputFolder(options);

const plugins: PluginInfo[] = [];
const pluginDecls = context.schema.declarations.filter((d): d is Plugin => isPlugin(d));
const pluginDecls = options.schema.declarations.filter((d): d is Plugin => isPlugin(d));

let prismaOutput = resolvePath('./prisma/schema.prisma', { schemaPath: context.schemaPath, name: '' });
let prismaOutput = resolvePath('./prisma/schema.prisma', { schemaPath: options.schemaPath, name: '' });

for (const pluginDecl of pluginDecls) {
const pluginProvider = this.getPluginProvider(pluginDecl);
Expand All @@ -73,59 +79,35 @@ export class PluginRunner {

const dependencies = this.getPluginDependencies(pluginModule);
const pluginName = this.getPluginName(pluginModule, pluginProvider);
const options: PluginOptions = { schemaPath: context.schemaPath, name: pluginName };
const pluginOptions: PluginOptions = { schemaPath: options.schemaPath, name: pluginName };

pluginDecl.fields.forEach((f) => {
const value = getLiteral(f.value) ?? getLiteralArray(f.value);
if (value === undefined) {
throw new PluginError(pluginName, `Invalid option value for ${f.name}`);
}
options[f.name] = value;
pluginOptions[f.name] = value;
});

plugins.push({
name: pluginName,
provider: pluginProvider,
dependencies,
options,
options: pluginOptions,
run: pluginModule.default as PluginFunction,
module: pluginModule,
});

if (pluginProvider === '@core/prisma' && typeof options.output === 'string') {
if (pluginProvider === '@core/prisma' && typeof pluginOptions.output === 'string') {
// record custom prisma output path
prismaOutput = resolvePath(options.output, options);
prismaOutput = resolvePath(pluginOptions.output, pluginOptions);
}
}

// make sure prerequisites are included
const corePlugins: Array<{ provider: string; options?: Record<string, unknown> }> = [
{ provider: '@core/prisma' },
{ provider: '@core/model-meta' },
{ provider: '@core/access-policy' },
];

if (getDataModels(context.schema).some((model) => hasValidationAttributes(model))) {
// '@core/zod' plugin is auto-enabled if there're validation rules
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
} else {
// add core dependency
corePlugins.push({ provider: dep });
}
}
});
// get core plugins that need to be enabled
const corePlugins = this.calculateCorePlugins(options, plugins);

// shift/insert core plugins to the front
for (const corePlugin of corePlugins.reverse()) {
const existingIdx = plugins.findIndex((p) => p.provider === corePlugin.provider);
if (existingIdx >= 0) {
Expand All @@ -141,7 +123,7 @@ export class PluginRunner {
name: pluginName,
provider: corePlugin.provider,
dependencies: [],
options: { schemaPath: context.schemaPath, name: pluginName, ...corePlugin.options },
options: { schemaPath: options.schemaPath, name: pluginName, ...corePlugin.options },
run: pluginModule.default,
module: pluginModule,
});
Expand All @@ -161,12 +143,17 @@ export class PluginRunner {
}
}

if (plugins.length === 0) {
console.log(colors.yellow('No plugins configured.'));
return;
}

const warnings: string[] = [];

let dmmf: DMMF.Document | undefined = undefined;
for (const { name, provider, run, options } of plugins) {
for (const { name, provider, run, options: pluginOptions } of plugins) {
// const start = Date.now();
await this.runPlugin(name, run, context, options, dmmf, warnings);
await this.runPlugin(name, run, options, pluginOptions, dmmf, warnings);
// console.log(`✅ Plugin ${colors.bold(name)} (${provider}) completed in ${Date.now() - start}ms`);
if (provider === '@core/prisma') {
// load prisma DMMF
Expand All @@ -175,14 +162,64 @@ export class PluginRunner {
});
}
}

console.log(colors.green(colors.bold('\n👻 All plugins completed successfully!')));

warnings.forEach((w) => console.warn(colors.yellow(w)));

console.log(`Don't forget to restart your dev server to let the changes take effect.`);
}

private calculateCorePlugins(options: PluginRunnerOptions, plugins: PluginInfo[]) {
const corePlugins: Array<{ provider: string; options?: Record<string, unknown> }> = [];

if (options.defaultPlugins) {
corePlugins.push(
{ provider: '@core/prisma' },
{ provider: '@core/model-meta' },
{ provider: '@core/access-policy' }
);
} else if (plugins.length > 0) {
// "@core/prisma" plugin is always enabled if any plugin is configured
corePlugins.push({ provider: '@core/prisma' });
}

// "@core/access-policy" has implicit requirements
if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) {
// make sure "@core/model-meta" is enabled
if (!corePlugins.find((p) => p.provider === '@core/model-meta')) {
corePlugins.push({ provider: '@core/model-meta' });
}

// '@core/zod' plugin is auto-enabled by "@core/access-policy"
// if there're validation rules
if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) {
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
}
}

// core plugins introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
if (dep.startsWith('@core/')) {
const existing = corePlugins.find((p) => p.provider === dep);
if (existing) {
// reset options to default
existing.options = undefined;
} else {
// add core dependency
corePlugins.push({ provider: dep });
}
}
});

return corePlugins;
}

private hasValidation(schema: Model) {
return getDataModels(schema).some((model) => hasValidationAttributes(model));
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
private getPluginName(pluginModule: any, pluginProvider: string): string {
return typeof pluginModule.name === 'string' ? (pluginModule.name as string) : pluginProvider;
Expand All @@ -200,7 +237,7 @@ export class PluginRunner {
private async runPlugin(
name: string,
run: PluginFunction,
context: Context,
runnerOptions: PluginRunnerOptions,
options: PluginOptions,
dmmf: DMMF.Document | undefined,
warnings: string[]
Expand All @@ -216,7 +253,10 @@ export class PluginRunner {
options,
},
async () => {
let result = run(context.schema, options, dmmf, config);
let result = run(runnerOptions.schema, options, dmmf, {
output: runnerOptions.output,
compile: runnerOptions.compile,
});
if (result instanceof Promise) {
result = await result;
}
Expand Down
11 changes: 6 additions & 5 deletions packages/schema/src/plugins/access-policy/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Model } from '@zenstackhq/language/ast';
import { PluginOptions } from '@zenstackhq/sdk';
import { PluginFunction } from '@zenstackhq/sdk';
import PolicyGenerator from './policy-guard-generator';

export const name = 'Access Policy';

export default async function run(model: Model, options: PluginOptions) {
return new PolicyGenerator().generate(model, options);
}
const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {
return new PolicyGenerator().generate(model, options, globalOptions);
};

export default run;
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
import {
ExpressionContext,
PluginError,
PluginGlobalOptions,
PluginOptions,
RUNTIME_PACKAGE,
analyzePolicies,
Expand Down Expand Up @@ -65,8 +66,8 @@ import { ExpressionWriter, FALSE, TRUE } from './expression-writer';
* Generates source file that contains Prisma query guard objects used for injecting database queries
*/
export default class PolicyGenerator {
async generate(model: Model, options: PluginOptions) {
let output = options.output ? (options.output as string) : getDefaultOutputFolder();
async generate(model: Model, options: PluginOptions, globalOptions?: PluginGlobalOptions) {
let output = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions);
if (!output) {
throw new PluginError(options.name, `Unable to determine output path, not running plugin`);
}
Expand Down Expand Up @@ -147,7 +148,14 @@ export default class PolicyGenerator {

sf.addStatements('export default policy');

const shouldCompile = options.compile !== false;
let shouldCompile = true;
if (typeof options.compile === 'boolean') {
// explicit override
shouldCompile = options.compile;
} else if (globalOptions) {
shouldCompile = globalOptions.compile;
}

if (!shouldCompile || options.preserveTsFiles === true) {
// save ts files
await saveProject(project);
Expand Down
Loading