package com.mzl.flower.config; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.mzl.flower.mapper.ip.BlackListMapper; import com.mzl.flower.entity.ip.BlackList; import com.mzl.flower.utils.IpUtil; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Component; import org.springframework.web.servlet.HandlerInterceptor; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.time.LocalDateTime; import java.time.ZoneOffset; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @Component @Slf4j public class SlidingWindowInterceptor implements HandlerInterceptor { /** * 限流周期 */ private int RATE_CYCLE = 10; /** * 单位时间划分的小周期 */ private int SUB_CYCLE = 10; /** * 限流请求数 */ private int thresholdPerCycle = 50; /** * 计数器, k-ip,value-当前窗口的开始时间值秒及当前窗口的计数 */ private final ConcurrentHashMap> ipCounters = new ConcurrentHashMap<>(); private final BlackListMapper blackListMapper; public SlidingWindowInterceptor(BlackListMapper blackListMapper) { this.blackListMapper = blackListMapper; } @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) { LocalDateTime now = LocalDateTime.now(); String ip = null; try { ip = IpUtil.getIpAddress(request); } catch (IOException e) { log.error("获取ip失败",e); ip = request.getRemoteAddr(); } if(checkBlackList(ip)){ //验证黑名单 try { response.sendError(HttpStatus.SERVICE_UNAVAILABLE.value(), "IP被禁用,请联系管理员!"); } catch (IOException e) { throw new RuntimeException(e); } log.error("禁用IP发送请求被拦截,请求ip="+ip); } boolean result = slidingWindowsTryAcquire(ip,now); if(!result){ try { response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(), "系统操作太频繁,请稍后再试!"); log.error("请求太频繁ip="+ip); } catch (IOException e) { throw new RuntimeException(e); } return false; } return true; } private boolean checkBlackList(String ip) { if (StringUtils.isNotBlank(ip)) { return blackListMapper.selectCount(new LambdaQueryWrapper().eq(BlackList::getIp, ip)) > 0; } return false; } /** * 滑动窗口时间算法实现 */ boolean slidingWindowsTryAcquire(String ip, LocalDateTime now) { long currentWindowTime = now.toEpochSecond(ZoneOffset.UTC) / SUB_CYCLE * SUB_CYCLE; //获取当前时间在哪个小周期窗口 int currentWindowNum = countCurrentWindow(currentWindowTime,ip); //当前窗口总请求数 //超过阀值限流 if (currentWindowNum >= thresholdPerCycle) { return false; } //计数器+1 ConcurrentHashMap counters = ipCounters.get(ip); if(counters == null){ counters = new ConcurrentHashMap<>(); } counters.put(currentWindowTime, counters.getOrDefault(currentWindowTime, 0) + 1); ipCounters.put(ip,counters); return true; } /** * 统计当前窗口的请求数 */ private int countCurrentWindow(long currentWindowTime,String ip) { //计算窗口开始位置 long startTime = currentWindowTime - SUB_CYCLE * (RATE_CYCLE / SUB_CYCLE-1); int count = 0; //遍历存储的计数器 ConcurrentHashMap counters = ipCounters.get(ip); if(counters == null){ return 0; } Iterator> iterator = counters.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry entry = iterator.next(); // 删除无效过期的子窗口计数器 if (entry.getKey() < startTime) { iterator.remove(); } else { //累加当前窗口的所有计数器之和 count =count + entry.getValue(); } } return count; } }