diff --git a/src/headers.mjs b/src/headers.mjs index 5395879..b41d821 100644 --- a/src/headers.mjs +++ b/src/headers.mjs @@ -2,6 +2,7 @@ * @typedef {import('./core.d.ts').PerPageHashes} PerPageHashes * @typedef {import('./main.d.ts').CSPDirectiveNames} CSPDirectiveNames * @typedef {import('./main.d.ts').CSPDirectives} CSPDirectives + * @typedef {import('./main.d.ts').CSPOptions} CSPOptions * @typedef {import('./main.d.ts').SecurityHeadersOptions} SecurityHeadersOptions */ @@ -79,23 +80,19 @@ export const parseCspDirectives = cspHeader => { } /** - * @param {Headers} headers + * @param {Record} plainHeaders * @param {PerPageHashes} pageHashes - * @param {SecurityHeadersOptions} securityHeadersOpts - * @returns {Headers} + * @param {CSPOptions} cspOpts */ -export const patchHeaders = (headers, pageHashes, securityHeadersOpts) => { - const directives = headers.has('content-security-policy') +export const patchCspHeader = (plainHeaders, pageHashes, cspOpts) => { + const directives = Object.hasOwn(plainHeaders, 'content-security-policy') ? { - ...securityHeadersOpts.contentSecurityPolicy?.cspDirectives, + ...cspOpts.cspDirectives, ...parseCspDirectives( - /** @type {string} */ (headers.get('content-security-policy')), + /** @type {string} */ (plainHeaders['content-security-policy']), ), } - : securityHeadersOpts.contentSecurityPolicy?.cspDirectives ?? - /** @type {CSPDirectives} */ ({}) - - const plainHeaders = Object.fromEntries(headers.entries()) + : cspOpts.cspDirectives ?? /** @type {CSPDirectives} */ ({}) if (pageHashes.scripts.size > 0) { setSrcDirective(directives, 'script-src', pageHashes.scripts) @@ -106,6 +103,20 @@ export const patchHeaders = (headers, pageHashes, securityHeadersOpts) => { if (Object.keys(directives).length > 0) { plainHeaders['content-security-policy'] = serialiseCspDirectives(directives) } +} + +/** + * @param {Headers} headers + * @param {PerPageHashes} pageHashes + * @param {SecurityHeadersOptions} securityHeadersOpts + * @returns {Headers} + */ +export const patchHeaders = (headers, pageHashes, securityHeadersOpts) => { + const plainHeaders = Object.fromEntries(headers.entries()) + + if (securityHeadersOpts.contentSecurityPolicy !== undefined) { + patchCspHeader(plainHeaders, pageHashes, securityHeadersOpts.contentSecurityPolicy) + } return new Headers(plainHeaders) } diff --git a/src/main.d.ts b/src/main.d.ts index 1baca2b..0611c2c 100644 --- a/src/main.d.ts +++ b/src/main.d.ts @@ -48,6 +48,11 @@ export type CSPOptions = { */ // sriHashesStrategy?: 'all' | 'perPage' // TODO: Enable in the future + /** + * - If set, it controls the "default" CSP directives (they can be overriden + * at runtime). + * - If not set, the middleware will use a minimal set of default directives. + */ cspDirectives?: CSPDirectives } diff --git a/tests/headers.test.mts b/tests/headers.test.mts index 1be5cb5..56a3df7 100644 --- a/tests/headers.test.mts +++ b/tests/headers.test.mts @@ -126,6 +126,18 @@ describe('patchHeaders', () => { expect(patchedHeaders.has('content-security-policy')).toBe(false) }) + it('does not set csp header if no contentSecurityPolicy option is set', () => { + const headers = new Headers() + const pageHashes = { + scripts: new Set(['abc1', 'xyz2']), + styles: new Set(['dbc1', 'xyz3', 'abc2']), + } + const settings: SecurityHeadersOptions = { /* contentSecurityPolicy: {} */ } + + const patchedHeaders = patchHeaders(headers, pageHashes, settings) + expect(patchedHeaders.has('content-security-policy')).toBe(false) + }) + it('sets csp header based on settings', () => { const headers = new Headers() const pageHashes = { scripts: new Set(), styles: new Set() } @@ -150,7 +162,7 @@ describe('patchHeaders', () => { scripts: new Set(['abc1', 'xyz2']), styles: new Set(['dbc1', 'xyz3', 'abc2']), } - const settings: SecurityHeadersOptions = {} + const settings: SecurityHeadersOptions = { contentSecurityPolicy: {} } const patchedHeaders = patchHeaders(headers, pageHashes, settings) expect(patchedHeaders.get('content-security-policy')).toBe( diff --git a/vitest.config.unit.mts b/vitest.config.unit.mts index 5249863..78196bf 100644 --- a/vitest.config.unit.mts +++ b/vitest.config.unit.mts @@ -19,10 +19,10 @@ export default defineConfig({ 'coverage-unit/**/*', ], thresholds: { - statements: 70.0, - branches: 75.0, - functions: 70.0, - lines: 70.0, + statements: 72.0, + branches: 77.0, + functions: 80.0, + lines: 72.0, }, reportsDirectory: 'coverage-unit', },