package com.mzl.flower.config;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.mzl.flower.mapper.ip.BlackListMapper;
|
import com.mzl.flower.entity.ip.BlackList;
|
import com.mzl.flower.utils.IpUtil;
|
import lombok.extern.slf4j.Slf4j;
|
import org.apache.commons.lang3.StringUtils;
|
import org.springframework.http.HttpStatus;
|
import org.springframework.stereotype.Component;
|
import org.springframework.web.servlet.HandlerInterceptor;
|
|
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletResponse;
|
import java.io.IOException;
|
import java.time.LocalDateTime;
|
import java.time.ZoneOffset;
|
import java.util.Iterator;
|
import java.util.Map;
|
import java.util.concurrent.ConcurrentHashMap;
|
|
@Component
|
@Slf4j
|
public class SlidingWindowInterceptor implements HandlerInterceptor {
|
|
/**
|
* 限流周期
|
*/
|
private int RATE_CYCLE = 10;
|
|
/**
|
* 单位时间划分的小周期
|
*/
|
private int SUB_CYCLE = 10;
|
/**
|
* 限流请求数
|
*/
|
private int thresholdPerCycle = 50;
|
/**
|
* 计数器, k-ip,value-当前窗口的开始时间值秒及当前窗口的计数
|
*/
|
private final ConcurrentHashMap<String, ConcurrentHashMap<Long, Integer>> ipCounters = new ConcurrentHashMap<>();
|
|
|
private final BlackListMapper blackListMapper;
|
|
public SlidingWindowInterceptor(BlackListMapper blackListMapper) {
|
this.blackListMapper = blackListMapper;
|
}
|
|
@Override
|
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
|
LocalDateTime now = LocalDateTime.now();
|
String ip = null;
|
try {
|
ip = IpUtil.getIpAddress(request);
|
} catch (IOException e) {
|
log.error("获取ip失败",e);
|
ip = request.getRemoteAddr();
|
}
|
if(checkBlackList(ip)){ //验证黑名单
|
try {
|
response.sendError(HttpStatus.SERVICE_UNAVAILABLE.value(), "IP被禁用,请联系管理员!");
|
} catch (IOException e) {
|
throw new RuntimeException(e);
|
}
|
log.error("禁用IP发送请求被拦截,请求ip="+ip);
|
}
|
boolean result = slidingWindowsTryAcquire(ip,now);
|
if(!result){
|
try {
|
response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(), "系统操作太频繁,请稍后再试!");
|
log.error("请求太频繁ip="+ip);
|
} catch (IOException e) {
|
throw new RuntimeException(e);
|
}
|
return false;
|
}
|
return true;
|
}
|
|
private boolean checkBlackList(String ip) {
|
if (StringUtils.isNotBlank(ip)) {
|
return blackListMapper.selectCount(new LambdaQueryWrapper<BlackList>().eq(BlackList::getIp, ip)) > 0;
|
}
|
return false;
|
}
|
|
/**
|
* 滑动窗口时间算法实现
|
*/
|
boolean slidingWindowsTryAcquire(String ip, LocalDateTime now) {
|
long currentWindowTime = now.toEpochSecond(ZoneOffset.UTC) / SUB_CYCLE * SUB_CYCLE; //获取当前时间在哪个小周期窗口
|
int currentWindowNum = countCurrentWindow(currentWindowTime,ip); //当前窗口总请求数
|
//超过阀值限流
|
if (currentWindowNum >= thresholdPerCycle) {
|
return false;
|
}
|
//计数器+1
|
ConcurrentHashMap<Long, Integer> counters = ipCounters.get(ip);
|
if(counters == null){
|
counters = new ConcurrentHashMap<>();
|
}
|
counters.put(currentWindowTime, counters.getOrDefault(currentWindowTime, 0) + 1);
|
ipCounters.put(ip,counters);
|
return true;
|
}
|
/**
|
* 统计当前窗口的请求数
|
*/
|
private int countCurrentWindow(long currentWindowTime,String ip) {
|
//计算窗口开始位置
|
long startTime = currentWindowTime - SUB_CYCLE * (RATE_CYCLE / SUB_CYCLE-1);
|
int count = 0;
|
//遍历存储的计数器
|
ConcurrentHashMap<Long, Integer> counters = ipCounters.get(ip);
|
if(counters == null){
|
return 0;
|
}
|
Iterator<Map.Entry<Long, Integer>> iterator = counters.entrySet().iterator();
|
while (iterator.hasNext()) {
|
Map.Entry<Long, Integer> entry = iterator.next();
|
// 删除无效过期的子窗口计数器
|
if (entry.getKey() < startTime) {
|
iterator.remove();
|
} else {
|
//累加当前窗口的所有计数器之和
|
count =count + entry.getValue();
|
}
|
}
|
return count;
|
}
|
}
|