1
zhujie
5 天以前 ec15861e14c66c38b1a8f5fffc6975d7da6c315c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package com.mzl.flower.config;
 
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.mzl.flower.mapper.ip.BlackListMapper;
import com.mzl.flower.entity.ip.BlackList;
import com.mzl.flower.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
 
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
 
@Component
@Slf4j
public class SlidingWindowInterceptor implements HandlerInterceptor {
 
    /**
     * 限流周期
     */
    private int RATE_CYCLE = 10;
 
    /**
     * 单位时间划分的小周期
     */
    private int SUB_CYCLE = 10;
    /**
     * 限流请求数
     */
    private int thresholdPerCycle = 50;
    /**
     * 计数器, k-ip,value-当前窗口的开始时间值秒及当前窗口的计数
     */
    private final ConcurrentHashMap<String, ConcurrentHashMap<Long, Integer>> ipCounters = new ConcurrentHashMap<>();
 
 
    private final BlackListMapper blackListMapper;
 
    public SlidingWindowInterceptor(BlackListMapper blackListMapper) {
        this.blackListMapper = blackListMapper;
    }
 
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) {
        LocalDateTime now = LocalDateTime.now();
        String ip = null;
        try {
            ip = IpUtil.getIpAddress(request);
        } catch (IOException e) {
            log.error("获取ip失败",e);
            ip = request.getRemoteAddr();
        }
        if(checkBlackList(ip)){ //验证黑名单
            try {
                response.sendError(HttpStatus.SERVICE_UNAVAILABLE.value(), "IP被禁用,请联系管理员!");
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            log.error("禁用IP发送请求被拦截,请求ip="+ip);
        }
        boolean result = slidingWindowsTryAcquire(ip,now);
        if(!result){
            try {
                response.sendError(HttpStatus.TOO_MANY_REQUESTS.value(), "系统操作太频繁,请稍后再试!");
                log.error("请求太频繁ip="+ip);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            return false;
        }
        return true;
    }
 
    private boolean checkBlackList(String ip) {
        if (StringUtils.isNotBlank(ip)) {
            return blackListMapper.selectCount(new LambdaQueryWrapper<BlackList>().eq(BlackList::getIp, ip)) > 0;
        }
       return false;
    }
 
    /**
     * 滑动窗口时间算法实现
     */
    boolean slidingWindowsTryAcquire(String ip, LocalDateTime now) {
        long currentWindowTime = now.toEpochSecond(ZoneOffset.UTC) / SUB_CYCLE * SUB_CYCLE; //获取当前时间在哪个小周期窗口
        int currentWindowNum = countCurrentWindow(currentWindowTime,ip); //当前窗口总请求数
        //超过阀值限流
        if (currentWindowNum >= thresholdPerCycle) {
            return false;
        }
        //计数器+1
        ConcurrentHashMap<Long, Integer> counters = ipCounters.get(ip);
        if(counters == null){
            counters = new ConcurrentHashMap<>();
        }
        counters.put(currentWindowTime, counters.getOrDefault(currentWindowTime, 0) + 1);
        ipCounters.put(ip,counters);
        return true;
    }
    /**
     * 统计当前窗口的请求数
     */
    private int countCurrentWindow(long currentWindowTime,String ip) {
        //计算窗口开始位置
        long startTime = currentWindowTime - SUB_CYCLE * (RATE_CYCLE / SUB_CYCLE-1);
        int count = 0;
        //遍历存储的计数器
        ConcurrentHashMap<Long, Integer> counters = ipCounters.get(ip);
        if(counters == null){
            return 0;
        }
        Iterator<Map.Entry<Long, Integer>> iterator = counters.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<Long, Integer> entry = iterator.next();
            // 删除无效过期的子窗口计数器
            if (entry.getKey() < startTime) {
                iterator.remove();
            } else {
                //累加当前窗口的所有计数器之和
                count =count + entry.getValue();
            }
        }
        return count;
    }
}