diff --git a/lib/Config.js b/lib/Config.js index f6f3273133..b10c1df346 100644 --- a/lib/Config.js +++ b/lib/Config.js @@ -1839,6 +1839,11 @@ class Config extends EventEmitter { if (config.rateLimiting?.enabled) { + // rate limiting uses the same localCache config defined for S3 to avoid + // config duplication. + assert(this.localCache, 'missing required property of rateLimit ' + + 'configuration: localCache'); + this.rateLimiting.enabled = true; assert.strictEqual(typeof config.rateLimiting.serviceUserArn, 'string'); diff --git a/lib/api/apiUtils/rateLimit/client.js b/lib/api/apiUtils/rateLimit/client.js new file mode 100644 index 0000000000..26332f0c63 --- /dev/null +++ b/lib/api/apiUtils/rateLimit/client.js @@ -0,0 +1,78 @@ +const fs = require('fs'); + +const Redis = require('ioredis'); + +const { config } = require('../../../Config'); + +const updateCounterScript = fs.readFileSync(`${__dirname }/updateCounter.lua`).toString(); + +const SCRIPTS = { + updateCounter: { + numberOfKeys: 1, + lua: updateCounterScript, + }, +}; + +class RateLimitClient { + constructor(redisConfig) { + this.redis = new Redis({ + ...redisConfig, + scripts: SCRIPTS, + lazyConnect: true, + }); + } + + /** + * @typedef {Object} CounterUpdateBatch + * @property {string} key - counter key + * @property {number} cost - cost to add to counter + */ + + /** + * @typedef {Object} CounterUpdateBatchResult + * @property {string} key - counter key + * @property {number} value - current value of counter + */ + + /** + * @callback RateLimitClient~batchUpdate + * @param {Error|null} err + * @param {CounterUpdateBatchResult[]|undefined} + */ + + /** + * Add cost to the counter at key. + * Returns the new value for the counter + * + * @param {CounterUpdateBatch[]} batch - batch of counter updates + * @param {RateLimitClient~batchUpdate} cb + */ + updateLocalCounters(batch, cb) { + const pipeline = this.redis.pipeline(); + for (const { key, cost } of batch) { + pipeline.updateCounter(key, cost); + } + + pipeline.exec((err, results) => { + if (err) { + cb(err); + return; + } + + cb(null, results.map((res, i) => ({ + key: batch[i].key, + value: res[1], + }))); + }); + } +} + +let instance; +if (config.rateLimiting.enabled) { + instance = new RateLimitClient(config.localCache); +} + +module.exports = { + instance, + RateLimitClient +}; diff --git a/lib/api/apiUtils/rateLimit/updateCounter.lua b/lib/api/apiUtils/rateLimit/updateCounter.lua new file mode 100644 index 0000000000..b12b33ad59 --- /dev/null +++ b/lib/api/apiUtils/rateLimit/updateCounter.lua @@ -0,0 +1,27 @@ +-- updateCounter +-- +-- Adds the passed COST to the GCRA counter at KEY. +-- If no counter currently exists a new one is created from the current time. +-- The key expiration is set to the updated value. +-- Returns the value of the updated key. + +local ts = redis.call('TIME') +local currentTime = ts[1] * 1000 +currentTime = currentTime + math.floor(ts[2] / 1000) + +local newValue = currentTime + tonumber(ARGV[1]) + +local counterExists = redis.call('EXISTS', KEYS[1]) +if counterExists == 1 then + local currentValue = tonumber(redis.call('GET', KEYS[1])) + if currentValue > currentTime then + newValue = currentValue + tonumber(ARGV[1]) + end +end + +redis.call('SET', KEYS[1], newValue) + +local expiry = math.ceil(newValue / 1000) +redis.call('EXPIREAT', KEYS[1], expiry) + +return newValue diff --git a/tests/functional/aws-node-sdk/test/rateLimit/client.js b/tests/functional/aws-node-sdk/test/rateLimit/client.js new file mode 100644 index 0000000000..139e88e469 --- /dev/null +++ b/tests/functional/aws-node-sdk/test/rateLimit/client.js @@ -0,0 +1,44 @@ +const assert = require('assert'); + +const { config } = require('../../../../../lib/Config'); +const { RateLimitClient } = require('../../../../../lib/api/apiUtils/rateLimit/client'); + + +const counterKey = 'foo'; + +describe('Test RateLimitClient', () => { + let client; + + before(done => { + client = new RateLimitClient(config.localCache); + client.redis.connect(done); + }); + + beforeEach(done => { + client.redis.del(counterKey, err => done(err)); + }); + + it('should set the value of an empty counter', done => { + const batch = [{ key: counterKey, cost: 10000 }]; + client.updateLocalCounters(batch, (err, res) => { + assert.ifError(err); + assert.strictEqual(res.length, 1); + assert.strictEqual(res[0].key, counterKey); + done(); + }); + }); + + it('should increment the value of an existing counter', done => { + const batch = [{ key: counterKey, cost: 10000 }]; + client.updateLocalCounters(batch, (err, res) => { + assert.ifError(err); + const { value: existingValue } = res[0]; + client.updateLocalCounters(batch, (err, res) => { + assert.ifError(err); + const { value: newValue } = res[0]; + assert(newValue > existingValue, `${newValue} is not greater than ${existingValue}`); + done(); + }); + }); + }); +}); diff --git a/tests/unit/api/apiUtils/rateLimit/client.js b/tests/unit/api/apiUtils/rateLimit/client.js new file mode 100644 index 0000000000..487322f8c0 --- /dev/null +++ b/tests/unit/api/apiUtils/rateLimit/client.js @@ -0,0 +1,83 @@ +const assert = require('assert'); + +const { RateLimitClient } = require('../../../../../lib/api/apiUtils/rateLimit/client'); + +class RedisStub { + constructor() { + this.data = {}; + this.execErr = null; + } + + pipeline() { + return new PipelineStub(this.execErr); + } + + setExecErr(err) { + this.execErr = err; + } +} + +class PipelineStub { + constructor(execErr) { + this.ops = []; + this.execErr = execErr; + } + + updateCounter(key, cost) { + this.ops.push([key, cost]); + } + + exec(cb) { + if (this.execErr) { + cb(this.execErr); + } else { + cb(null, this.ops.map(v => [1, v[1]])); + } + } +} + +describe('test RateLimitClient', () => { + let client; + + before(() => { + client = new RateLimitClient({}); + }); + + beforeEach(() => { + client.redis = new RedisStub(); + }); + + it('should update a batch of counters', done => { + const batch = [ + { key: 'foo', cost: 100 }, + { key: 'bar', cost: 200 }, + { key: 'qux', cost: 300 }, + ]; + + client.updateLocalCounters(batch, (err, results) => { + assert.ifError(err); + assert.deepStrictEqual(results, [ + { key: 'foo', value: 100 }, + { key: 'bar', value: 200 }, + { key: 'qux', value: 300 }, + ]); + done(); + }); + }); + + it('should pass through errors', done => { + const execErr = new Error('bad stuff'); + client.redis.setExecErr(execErr); + const batch = [ + { key: 'foo', cost: 100 }, + { key: 'bar', cost: 200 }, + { key: 'qux', cost: 300 }, + ]; + + client.updateLocalCounters(batch, (err, results) => { + assert.strictEqual(err, execErr); + assert.strictEqual(results, undefined); + done(); + }); + }); +});