Enhance-Request-And-Response
前言
基于过滤器实现请求响应的增强
Filter 拦截请求
import com.light.cloud.common.web.enhance.EnhanceHandler;
import com.light.cloud.common.web.enhance.EnhanceHandlerProxy;
import com.light.cloud.common.web.enhance.RequestWrapper;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Servlet filter that enhances requests and responses.
 * <ul>
 * <li>Wraps the request and response so that their bodies can be inspected.</li>
 * <li>Invokes the configured {@link EnhanceHandler}s (through {@link EnhanceHandlerProxy})
 * before and after the rest of the filter chain.</li>
 * </ul>
 */
@Slf4j
public class RequestEnhanceFilter extends OncePerRequestFilter {

    private final EnhanceHandlerProxy enhanceHandlerProxy;

    /**
     * Creates the filter.
     *
     * @param enhanceHandlers handlers handed to the {@link EnhanceHandlerProxy} used by this filter
     */
    public RequestEnhanceFilter(List<EnhanceHandler> enhanceHandlers) {
        this.enhanceHandlerProxy = new EnhanceHandlerProxy(enhanceHandlers);
    }

    /**
     * Wraps the request/response, runs the pre-handlers, continues the chain, runs the
     * post-handlers and finally flushes the cached response body to the client.
     *
     * @param request     current request
     * @param response    current response
     * @param filterChain remaining filter chain
     * @throws ServletException propagated from the filter chain
     * @throws IOException      propagated from the filter chain or from writing the response body
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
                                    FilterChain filterChain) throws ServletException, IOException {
        // Reuse caching wrappers installed by an upstream filter instead of stacking new ones.
        ContentCachingRequestWrapper requestWrapper = request instanceof ContentCachingRequestWrapper
                ? (ContentCachingRequestWrapper) request
                : new ContentCachingRequestWrapper(request);

        // Remember whether this filter created the response wrapper: only its creator may
        // flush the cached body. If an upstream filter supplied the wrapper, writing the
        // cached bytes to "response" would append them to the very same cache and
        // duplicate the body.
        boolean responseWrappedHere = !(response instanceof ContentCachingResponseWrapper);
        ContentCachingResponseWrapper responseWrapper = responseWrappedHere
                ? new ContentCachingResponseWrapper(response)
                : (ContentCachingResponseWrapper) response;

        // Second wrapper that makes the request body re-readable.
        RequestWrapper wrapperRequest = new RequestWrapper(requestWrapper);

        // Pre-processing.
        Map<String, Object> context = new HashMap<>();
        context.put(EnhanceHandler.CONTEXT_KEY_URL, request.getRequestURL());
        context.put(EnhanceHandler.CONTEXT_KEY_URI, request.getRequestURI());
        enhanceHandlerProxy.preHandle(wrapperRequest, context);

        try {
            // Invoke the rest of the chain (and ultimately the target handler).
            filterChain.doFilter(wrapperRequest, responseWrapper);

            // Post-processing runs BEFORE the body is copied: copying drains the wrapper's
            // cache, so handlers reading the response content afterwards would see nothing.
            enhanceHandlerProxy.postHandle(responseWrapper, context);
        } finally {
            if (responseWrappedHere) {
                // Always hand the cached body to the real response, even when the chain or a
                // handler failed; otherwise everything written so far is silently dropped.
                // copyBodyToResponse() also takes care of Content-Length and flushing, and
                // leaves closing the stream to the container.
                responseWrapper.copyBodyToResponse();
            }
        }
    }
}
RequestWrapper 包装请求,可以多次读取 InputStream
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.MediaType;
import org.springframework.web.util.ContentCachingRequestWrapper;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Enumeration;
import java.util.Map;
/**
 * Request wrapper that caches the request body so that it can be read more than once.
 * <p>
 * 1. {@link ContentCachingRequestWrapper} exposes URL parameters directly, see
 * {@link ContentCachingRequestWrapper#getParameterMap()}.
 * <p>
 * 2. The request body is streamed, so {@link ContentCachingRequestWrapper} can only cache it
 * after {@link ContentCachingRequestWrapper#getInputStream()} has been consumed; afterwards the
 * body is only available through {@link ContentCachingRequestWrapper#getContentAsByteArray()}.
 * <p>
 * 3. Filters and interceptors sometimes need the body <em>before</em> the request is handled,
 * which is why this class reads and caches it eagerly in its constructor.
 * <p>
 * Only bodies whose content type is included by {@code application/json} are cached. For every
 * other request (or when reading the body failed) nothing is cached and stream/reader access is
 * delegated to the wrapped request unchanged.
 *
 * @see ContentCachingRequestWrapper
 */
@Slf4j
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Raw bytes of the cached body, or {@code null} when the body was not cached. */
    private final byte[] cachedBody;

    /** Cached body decoded as UTF-8, or {@code null} when the body was not cached. */
    private final String body;

    /**
     * Wraps the given request and eagerly caches its body when it is JSON.
     *
     * @param request request to wrap
     */
    public RequestWrapper(HttpServletRequest request) {
        super(request);
        this.cachedBody = parseBytesBody(request);
        this.body = this.cachedBody == null ? null : new String(this.cachedBody, StandardCharsets.UTF_8);
    }

    // region implements ServletRequest

    /**
     * Returns a fresh stream over the cached body on every call. When nothing was cached the
     * wrapped request's own stream is returned instead.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (cachedBody == null) {
            // Nothing cached (non-JSON request or failed read): do not hide the real body.
            return super.getInputStream();
        }
        return new CachedBodyServletInputStream(cachedBody);
    }

    /**
     * Returns a fresh UTF-8 reader over the cached body on every call. When nothing was cached
     * the wrapped request's own reader is returned instead.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (cachedBody == null) {
            return super.getReader();
        }
        // Explicit charset: the platform default must not decide how the body is decoded.
        return new BufferedReader(
                new InputStreamReader(new ByteArrayInputStream(cachedBody), StandardCharsets.UTF_8));
    }

    // endregion

    /**
     * @return the cached body decoded as UTF-8, or {@code null} when the body was not cached
     */
    public String getBody() {
        return this.body;
    }

    /**
     * @return a copy of the cached body bytes; an empty array when the body was not cached
     */
    public byte[] getContentAsByteArray() {
        // Defensive copy so callers cannot modify the cache.
        return this.cachedBody == null ? new byte[0] : this.cachedBody.clone();
    }

    /**
     * Reads the complete request body.
     *
     * @param request request whose body is read
     * @return the body bytes, or {@code null} when the request is not JSON or reading failed
     */
    private byte[] parseBytesBody(ServletRequest request) {
        if (!shouldParse(request)) {
            return null;
        }
        try (InputStream inputStream = request.getInputStream()) {
            return inputStream.readAllBytes();
        } catch (IOException e) {
            // Best effort: a failed read is logged and treated as "nothing cached".
            log.error("读取请求参数失败!", e);
            return null;
        }
    }

    /**
     * Decides whether the body should be cached.
     *
     * @param request request to inspect
     * @return {@code true} only when the content type parses and is included by
     * {@code application/json}
     */
    private boolean shouldParse(ServletRequest request) {
        String contentType = request.getContentType();
        if (!StringUtils.isNotBlank(contentType)) {
            return false;
        }
        try {
            MediaType mediaType = MediaType.parseMediaType(contentType);
            return MediaType.APPLICATION_JSON.includes(mediaType);
        } catch (IllegalArgumentException ignored) {
            // An unparsable Content-Type header simply means the body is not cached.
            return false;
        }
    }

    /**
     * {@link ServletInputStream} that replays an in-memory byte array.
     */
    private static final class CachedBodyServletInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        CachedBodyServletInputStream(byte[] content) {
            this.delegate = new ByteArrayInputStream(content);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            // Bulk read instead of the byte-by-byte default inherited from InputStream.
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // All data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener listener) {
            // Non-blocking reads are not supported by this replay stream; the listener is ignored.
        }
    }
}