--[[
Copyright (c) 2011-2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

-- Skip the whole module when the `confighelp` global is set
if confighelp then
  return
end

-- A plugin that implements ratelimits using redis

-- Shared empty table, used as a fallback when indexing possibly-nil values
local E = {}
-- Module name: config section, debug log module and name for disable_module
local N = 'ratelimit'
-- Redis connection parameters, filled while the configuration is processed
local redis_params
-- Module defaults; overridden by the options of the config section
local settings = {
  -- Senders that are considered as bounce
  bounce_senders = {
    'postmaster',
    'mailer-daemon',
    '',
    'null',
    'fetchmail-daemon',
    'mdaemon',
  },
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = {
    'postmaster',
    'mailer-daemon',
  },
  -- Prepended to the hashed key of every bucket
  prefix = 'RL',
  -- Multipliers applied to the dynamic rate of a bucket (ham / spam)
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  -- Multipliers applied to the dynamic burst of a bucket (ham / spam)
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  -- Bounds passed to the update script for the dynamic multipliers
  max_rate_mult = 5,
  max_bucket_mult = 10,
  -- Bucket expiration in seconds: 2 days by default
  expire = 2 * 24 * 60 * 60,
  -- Parsed buckets indexed by limit name
  limits = {},
  -- When false, tasks coming from rspamc or the controller are not checked
  allow_local = false,
}

-- Checks bucket, updating it if needed
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- Returns an array {limited, burst, dynamic rate, dynamic burst} where
-- `limited` is 1 if message should be ratelimited and 0 if not
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
-- The leak time is stored in `l` whenever a leak is applied: otherwise
-- consecutive checks that are not followed by an update would subtract
-- the same time interval from the bucket over and over again.
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  local dynr, dynb = 0, 0
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, 0, 1, 1}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  -- Always read the rate multiplier, so the reported value is the real one
  dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
  -- Perform leak
  if burst > 0 then
   if last < now then
    local rate = tonumber(KEYS[3]) * dynr
    local leaked = ((now - last) * rate)
    -- A bucket cannot leak more than it holds
    if leaked > burst then
     leaked = burst
    end
    burst = burst - leaked
    redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
    -- Remember the leak time to avoid leaking the same interval twice
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
   end
  else
   burst = 0
   redis.call('HSET', KEYS[1], 'b', '0')
  end

  dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

  if (burst + 1) * dynb > tonumber(KEYS[4]) then
   return {1, tostring(burst), tostring(dynr), tostring(dynb)}
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb)}
]]
local bucket_check_id


-- Updates a bucket
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier
-- KEYS[4] - dynamic burst multiplier
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- Returns an array {burst before this update, dynamic rate, dynamic burst}
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier
--   db - current dynamic burst multiplier
-- The bounds are checked with respect to the direction of the change:
-- a multiplier that has reached its upper bound can still be decreased
-- and a multiplier that has reached its lower bound can still be increased.
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
  local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000

  local rate_mult = tonumber(KEYS[3])
  local max_rate = tonumber(KEYS[5])
  if (rate_mult > 1.0 and dr < max_rate) or
     (rate_mult < 1.0 and dr > 1.0 / max_rate) then
    dr = dr * rate_mult
    redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
  end

  local burst_mult = tonumber(KEYS[4])
  local max_burst = tonumber(KEYS[6])
  if (burst_mult > 1.0 and db < max_burst) or
     (burst_mult < 1.0 and db > 1.0 / max_burst) then
    db = db * burst_mult
    redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
  end

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
local bucket_update_id

-- message_func(task, limit_type, prefix, bucket)
-- Default builder of the message that accompanies a soft reject; only the
-- name of the exceeded limit is used, the other arguments are ignored.
local message_func = function(_, limit_type, _, _)
  local reject_fmt = 'Ratelimit "%s" exceeded'
  return reject_fmt:format(limit_type)
end

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local rspamd_lua_utils = require "lua_util"
local lua_redis = require "lua_redis"
local fun = require "fun"
local lua_maps = require "lua_maps"
local lua_util = require "lua_util"
local rspamd_hash = require "rspamd_cryptobox_hash"


--- Registers both bucket scripts in lua_redis and stores the returned ids
-- in `bucket_check_id` and `bucket_update_id`.
-- Both arguments are accepted but not used here.
local function load_scripts(cfg, ev_base)
  local register_script = lua_redis.add_redis_script

  -- Check script first, update script second
  bucket_check_id = register_script(bucket_check_script, redis_params)
  bucket_update_id = register_script(bucket_update_script, redis_params)
end

-- LPeg grammar for string limits, built on the first call and reused later
local limit_parser
--- Parses a limit written as "<number>[k|m|g] / <number>[s|m|h|d]".
-- @param lim string with the limit definition
-- @param no_error if true, a parse failure is not logged
-- @return time part (scaled by its suffix) and number part (scaled by its
-- suffix), in that order; nil if the string cannot be parsed or the time
-- part is zero
local function parse_string_limit(lim, no_error)
  local lpeg = require "lpeg"

  -- Multipliers of the time suffixes (seconds)
  local time_units = { s = 1, m = 60, h = 3600, d = 86400 }
  -- Multipliers of the number suffixes
  local num_units = { [''] = 1, k = 1000, m = 1000000, g = 1000000000 }

  local function parse_time_suffix(s)
    local mult = time_units[s]
    if mult then
      return mult
    end
  end
  local function parse_num_suffix(s)
    local mult = num_units[s]
    if mult then
      return mult
    end
  end

  if not limit_parser then
    local P, R, S = lpeg.P, lpeg.R, lpeg.S
    -- Folds the captured number and the optional suffix multiplier
    local function product(acc, val)
      return acc * val
    end
    local sign = S("+-")
    local digits = R("09") ^ 1
    local spaces = S(" ") ^ 0

    local grammar = {}
    grammar.integer = (sign ^ -1) * digits
    grammar.fractional = P(".") * digits
    grammar.number = (grammar.integer * (grammar.fractional ^ -1)) +
        (sign * grammar.fractional)
    grammar.time = lpeg.Cf(
        lpeg.Cc(1) *
            (grammar.number / tonumber) *
            ((S("smhd") / parse_time_suffix) ^ -1),
        product)
    grammar.suffixed_number = lpeg.Cf(
        lpeg.Cc(1) *
            (grammar.number / tonumber) *
            ((S("kmg") / parse_num_suffix) ^ -1),
        product)
    grammar.limit = lpeg.Ct(
        grammar.suffixed_number * spaces * S("/") * spaces * grammar.time)

    limit_parser = grammar
  end

  local parsed = lpeg.match(limit_parser.limit, lim)
  if parsed then
    local count, period = parsed[1], parsed[2]
    if count and period and period ~= 0 then
      return period, count
    end
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end

--- Converts a limit definition into a flat list of buckets.
-- Supported definitions:
--  * a string parsed by `parse_string_limit`
--  * old limit in format [burst, rate], numbers or numeric strings
--  * a bucket table with `burst` and `rate` fields (copied, other fields
--    such as per-bucket factors are preserved)
--  * a list of any of the above, which is parsed recursively and flattened
-- @param name limit name, used in log messages
-- @param data limit definition
-- @return list of buckets; each has numeric `burst` and `rate`
-- (`rate` is divided by 1000 by the checker to get messages/ms)
local function parse_limit(name, data)
  local buckets = {}
  if type(data) == 'table' then
    local old_burst, old_rate = tonumber(data[1]), tonumber(data[2])

    if #data == 2 and old_burst and old_rate then
      -- Old style ratelimit
      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
      if old_burst > 0 and old_rate > 0 then
        -- Store converted numbers: strings would be dropped by the filter
        buckets[#buckets + 1] = {
          burst = old_burst,
          rate = old_rate
        }
      elseif old_burst ~= 0 then
        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
      else
        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
      end
    elseif #data == 0 and tonumber(data.burst) and tonumber(data.rate) then
      -- Proper bucket table: copy it to leave the source definition intact
      local bucket = {}
      for k, v in pairs(data) do
        bucket[k] = v
      end
      bucket.burst = tonumber(data.burst)
      bucket.rate = tonumber(data.rate)
      buckets[#buckets + 1] = bucket
    else
      -- Recursively parse each element and flatten the list
      for _, d in ipairs(data) do
        for _, b in ipairs(parse_limit(name, d)) do
          buckets[#buckets + 1] = b
        end
      end
    end
  elseif type(data) == 'string' then
    local rep_rate, burst = parse_string_limit(data)

    if rep_rate and burst then
      buckets[#buckets + 1] = {
        burst = burst,
        rate = 1.0 / rep_rate -- reciprocal
      }
    end
  end

  -- Filter valid
  local valid = {}
  for _, b in ipairs(buckets) do
    if type(b.burst) == 'number' and type(b.rate) == 'number' then
      valid[#valid + 1] = b
    end
  end

  return valid
end

--- Tells whether the given sender is listed in `settings.bounce_senders`
-- @param from value to look for (the callers pass the user part of a sender)
-- @return result of `fun.any` over the configured bounce senders
local function check_bounce(from)
  local function is_same_sender(candidate)
    return candidate == from
  end

  return fun.any(is_same_sender, settings.bounce_senders)
end

-- Keywords that may be combined (separated by '_') in a limit name.
-- Each entry provides `get_value(task)` that returns the part of the bucket
-- key for this keyword, or nil when the keyword does not apply to the task.
local keywords = {
  -- Address returned by task:get_ip(), if valid
  ip = {
    get_value = function(task)
      local addr = task:get_ip()
      if not addr or not addr:is_valid() then
        return nil
      end
      return tostring(addr)
    end,
  },
  -- Same as `ip`, but addresses reported as local are skipped
  rip = {
    get_value = function(task)
      local addr = task:get_ip()
      if not addr or not addr:is_valid() or addr:is_local() then
        return nil
      end
      return tostring(addr)
    end,
  },
  -- Lowercased address of the first element of task:get_from(0)
  from = {
    get_value = function(task)
      local sender = (task:get_from(0) or E)[1] or E
      if sender.addr then
        return string.lower(sender.addr)
      end
      return nil
    end,
  },
  -- '_' for messages without a sender user or with a bounce sender
  bounce = {
    get_value = function(task)
      local sender = (task:get_from(0) or E)[1] or E
      if not sender.user then
        return '_'
      end
      if check_bounce(sender.user) then
        return '_'
      end
      return nil
    end,
  },
  -- Value of the `asn` mempool variable
  asn = {
    get_value = function(task)
      return task:get_mempool():get_variable('asn') or nil
    end,
  },
  -- Value returned by task:get_user()
  user = {
    get_value = function(task)
      return task:get_user() or nil
    end,
  },
  -- Value returned by task:get_principal_recipient()
  to = {
    get_value = function(task)
      return task:get_principal_recipient()
    end,
  },
}

--- Builds the key of a bucket for the given task
-- The key starts with a number derived from the bucket burst, followed by
-- the value of each keyword of the limit name; parts are joined with ':'.
-- @param task task to get keyword values from
-- @param rtype limit name: keywords separated by '_'
-- @param bucket bucket table, only `burst` is read
-- @return key string, or nil if any keyword is unknown or has no value
local function gen_rate_key(task, rtype, bucket)
  local parts = { tostring(lua_util.round(100000.0 / bucket.burst)) }
  local user_seen = false

  for _, kw in ipairs(lua_util.str_split(rtype, '_')) do
    local handler = keywords[kw]
    local value

    if handler and type(handler['get_value']) == 'function' then
      value = handler['get_value'](task)
    end

    -- A single missing component makes the whole limit inapplicable
    if not value then
      return nil
    end

    if kw == 'user' then
      user_seen = true
    end

    if type(value) == 'string' then
      parts[#parts + 1] = value
    else
      parts[#parts + 1] = tostring(value)
    end
  end

  -- User based limits require an authenticated user
  if user_seen and not task:get_user() then
    return nil
  end

  return table.concat(parts, ":")
end

--- Creates a descriptor of a bucket stored in redis
-- Fills the missing ham/spam factors of the bucket (in place) from the
-- module settings.
-- @param redis_key bucket key that is hashed
-- @param name limit name
-- @param bucket bucket table
-- @return table with `bucket`, `name` and `hash`, where `hash` is
-- `settings.prefix` followed by at most 24 characters of the base32 hash
-- (and not more characters than the key itself has)
local function make_prefix(redis_key, name, bucket)
  local hash_len = math.min(24, #redis_key)
  local digest = rspamd_hash.create(redis_key):base32()
  local descriptor = {
    bucket = bucket,
    name = name,
    hash = settings.prefix .. string.sub(digest, 1, hash_len),
  }

  -- Fill defaults
  local factors = {
    'spam_factor_rate',
    'ham_factor_rate',
    'spam_factor_burst',
    'ham_factor_burst',
  }
  for _, factor in ipairs(factors) do
    if not bucket[factor] then
      bucket[factor] = settings[factor]
    end
  end

  return descriptor
end

--- Adds descriptors of all applicable buckets of a limit to `prefixes`
-- @param task task being checked
-- @param k limit name
-- @param v list of buckets of this limit
-- @param prefixes table updated in place: bucket key -> bucket descriptor
-- @return number of buckets for which a key has been generated
local function limit_to_prefixes(task, k, v, prefixes)
  local added = 0

  for _, bucket in ipairs(v) do
    local redis_key = gen_rate_key(task, k, bucket)

    -- Buckets without a key do not apply to this task
    if redis_key then
      added = added + 1
      prefixes[redis_key] = make_prefix(redis_key, k, bucket)
    end
  end

  return added
end

--- Prefilter callback: checks all buckets that apply to the task
-- Nothing is checked for tasks from rspamc/controller (unless
-- `allow_local` is set) and for whitelisted IPs, recipients and users.
-- Descriptors of the checked buckets are saved in the task cache under
-- 'ratelimit_prefixes' so that they can be updated later.
-- When a bucket is exceeded, either only `settings.symbol` is inserted or
-- the task gets a 'soft reject' pre-result (with `settings.info_symbol`
-- inserted if configured).
-- @param task task to check
local function ratelimit_cb(task)
  if not settings.allow_local and
          rspamd_lua_utils.is_rspamc_or_controller(task) then return end

  -- Get initial task data
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return
    end
  end
  -- Parse all rcpts
  local rcpts = task:get_recipients()
  if rcpts and settings.whitelisted_rcpts then
    -- Both the user part and the full address of each recipient are checked
    local rcpts_user = {}
    fun.each(function(r)
      fun.each(function(field)
        if r[field] then
          table.insert(rcpts_user, r[field])
        end
      end, {'user', 'addr'})
    end, rcpts)

    if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
      return
    end
  end
  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    -- Unauthenticated tasks have no user to look up
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return
    end
  end
  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k,v in pairs(settings.limits) do
    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  end

  for k, hdl in pairs(settings.custom_keywords or E) do
    -- A handler returns the bucket key and the limit definition
    local ret, redis_key, bd = pcall(hdl, task)

    if ret then
      local bucket = parse_limit(k, bd)
      -- Handlers may return nothing if the limit is not applicable:
      -- count a prefix only if it has been really added
      if redis_key and bucket[1] then
        prefixes[redis_key] = make_prefix(redis_key, k, bucket[1])
        nprefixes = nprefixes + 1
      end
    else
      -- On failure the second value returned by pcall is the error
      rspamd_logger.errx(task, 'cannot call handler for %s: %s',
          k, redis_key)
    end
  end

  -- Creates a callback for the reply of the check script:
  -- data is {limited, burst, dynamic rate, dynamic burst}
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
      elseif type(data) == 'table' and data[1] and data[1] == 1 then
        -- set symbol only and do NOT soft reject
        if settings.symbol then
          task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
          rspamd_logger.infox(task,
              'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
              lim_name, prefix,
              bucket.burst, bucket.rate,
              data[2], data[3], data[4])
          return
        -- set INFO symbol and soft reject
        elseif settings.info_symbol then
          task:insert_result(settings.info_symbol, 1.0,
              lim_name .. "(" .. prefix .. ")")
        end
        rspamd_logger.infox(task,
            'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4])
        task:set_pre_result('soft reject',
                message_func(task, lim_name, prefix, bucket))
      end
    end
  end

  -- Don't do anything if pre-result has been already set
  if task:has_pre_result() then return end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds
    -- Now call check script for all defined prefixes

    for pr,value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
      rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
              {key = value.hash, task = task, is_write = true},
              gen_check_cb(pr, bucket, value.name),
              {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
                  tostring(settings.expire)})
    end
  end
end

--- Idempotent callback: updates all buckets checked for this task
-- Buckets are taken from the 'ratelimit_prefixes' task cache entry. Nothing
-- is updated if a pre-result has been set for the task.
-- The dynamic multipliers of each bucket are adjusted using the spam
-- factors if the metric action differs from 'no action' and using the ham
-- factors otherwise.
-- @param task task that has been processed
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if not prefixes then
    return
  end

  if task:has_pre_result() then
    -- Already rate limited/greylisted, do nothing
    rspamd_logger.debugm(N, task, 'pre-action has been set, do not update')
    return
  end

  local is_spam = not (task:get_metric_action() == 'no action')
  -- The same timestamp (in milliseconds) is used for all buckets
  local now = lua_util.round(rspamd_util.get_time() * 1000.0)

  -- Update each bucket
  for k, v in pairs(prefixes) do
    local bucket = v.bucket
    local function update_bucket_cb(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
                k, err)
      else
        -- Reply is {burst, dynamic rate, dynamic burst}
        local res = type(data) == 'table' and data or {}
        rspamd_logger.debugm(N, task,
            "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
            v.name, k, v.hash,
            bucket.burst, bucket.rate,
            res[1], res[2], res[3])
      end
    end

    local mult_burst, mult_rate
    if is_spam then
      mult_burst = bucket.spam_factor_burst or 1.0
      mult_rate = bucket.spam_factor_rate or 1.0
    else
      mult_burst = bucket.ham_factor_burst or 1.0
      -- The rate multiplier comes from the *rate* factor, not the burst one
      mult_rate = bucket.ham_factor_rate or 1.0
    end

    lua_redis.exec_redis_script(bucket_update_id,
            {key = v.hash, task = task, is_write = true},
            update_bucket_cb,
            {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
             tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
             tostring(settings.expire)})
  end
end

-- Module configuration: runs once when the plugin is loaded. Nothing is
-- registered if the module has no configuration section or no redis servers.
local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      local buckets = parse_limit(t, lim)

      if buckets and #buckets > 0 then
        settings.limits[t] = buckets
      end
    end, opts['rates'])
  end

  -- Names of the limits that have at least one bucket
  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
          'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Ret, ret, ret: stupid legacy stuff:
  -- If we have a string with commas then load it as as static map
  -- otherwise, apply normal logic of Rspamd maps

  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
    end
  elseif type(wrcpts) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
      'Ratelimit whitelisted rcpts')
  else
    -- Stupid default...
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
      'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
      'Ratelimit whitelist user map')
  end

  -- Custom keywords are loaded from a Lua file that returns either a table
  -- (name -> handler) or a single handler function
  settings.custom_keywords = {}
  if opts['custom_keywords'] then
    -- Load and run separately: a missing file or a syntax error is reported
    -- by loadfile and must not be hidden behind "attempt to call a nil value"
    local chunk, load_err = loadfile(opts['custom_keywords'])

    if not chunk then
      rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
          opts['custom_keywords'], load_err)
    else
      local ret, res_or_err = pcall(chunk)

      if ret then
        if type(res_or_err) == 'table' then
          for k,hdl in pairs(res_or_err) do
            settings['custom_keywords'][k] = hdl
          end
        elseif type(res_or_err) == 'function' then
          settings['custom_keywords']['custom'] = res_or_err
        end
      else
        rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
            opts['custom_keywords'], res_or_err)
        settings['custom_keywords'] = {}
      end
    end
  end

  if opts['message_func'] then
    -- The option holds Lua code that returns the message function.
    -- `load` accepts strings only in LuaJIT and Lua 5.2+, plain Lua 5.1
    -- needs `loadstring`
    message_func = assert((loadstring or load)(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')

  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 7,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    -- The check symbol is registered under the configured symbol name
    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }
  end
end

-- Register the redis scripts when the configuration is loaded.
-- Without redis parameters (no configuration or no servers) the module is
-- not active, so there is nothing to register the scripts for.
rspamd_config:add_on_load(function(cfg, ev_base, worker)
  if redis_params then
    load_scripts(cfg, ev_base)
  end
end)