package com.mzl.flower.config; import com.mzl.flower.config.exception.SelfAuth2Exception; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.exceptions.InsufficientScopeException; import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; import org.springframework.security.oauth2.common.exceptions.OAuth2Exception; import org.springframework.security.oauth2.provider.error.WebResponseExceptionTranslator; import org.springframework.security.web.util.ThrowableAnalyzer; import org.springframework.web.HttpRequestMethodNotSupportedException; import java.io.IOException; public class SelfWebResponseExceptionTranslator implements WebResponseExceptionTranslator { private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer(); @Override public ResponseEntity translate(Exception e) throws Exception { Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e); Exception ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class, causeChain); if (ase != null) { return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e)); } ase = (AccessDeniedException) throwableAnalyzer .getFirstThrowableOfType(AccessDeniedException.class, causeChain); if (ase instanceof AccessDeniedException) { return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase)); } ase = (InvalidGrantException) throwableAnalyzer .getFirstThrowableOfType(InvalidGrantException.class, causeChain); if (ase != null) { return handleOAuth2Exception(new InvalidException(ase.getMessage(), ase)); } ase = (AccessDeniedException) throwableAnalyzer .getFirstThrowableOfType(AccessDeniedException.class, causeChain); if (ase instanceof AccessDeniedException) { return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase)); } ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType( HttpRequestMethodNotSupportedException.class, causeChain); if (ase instanceof HttpRequestMethodNotSupportedException) { return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase)); } ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain); if (ase != null) { return handleOAuth2Exception((OAuth2Exception) ase); } return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e)); } private ResponseEntity handleOAuth2Exception(OAuth2Exception e) throws IOException { int status = e.getHttpErrorCode(); HttpHeaders headers = new HttpHeaders(); headers.set("Cache-Control", "no-store"); headers.set("Pragma", "no-cache"); if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) { headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary())); } ResponseEntity response = new ResponseEntity(new SelfAuth2Exception(e.getMessage(), String.valueOf(e.getHttpErrorCode())), headers, HttpStatus.valueOf(status)); return response; } private static class InvalidException extends SelfAuth2Exception { public InvalidException(String msg, Throwable t) { super(msg, t); } @Override public String getOAuth2ErrorCode() { return "invalid_exception"; } @Override public int getHttpErrorCode() { return 426; } } private static class ForbiddenException extends SelfAuth2Exception { public ForbiddenException(String msg, Throwable t) { super(msg, t); } @Override public String getOAuth2ErrorCode() { return "access_denied"; } @Override public int getHttpErrorCode() { return 403; } } private static class ServerErrorException extends SelfAuth2Exception { public ServerErrorException(String msg, Throwable t) { super(msg, t); } @Override public String getOAuth2ErrorCode() { return "server_error"; } @Override public int getHttpErrorCode() { return 500; } } private static class UnauthorizedException extends SelfAuth2Exception { public UnauthorizedException(String msg, Throwable t) { super(msg, t); } @Override public String getOAuth2ErrorCode() { return "unauthorized"; } @Override public int getHttpErrorCode() { return 401; } } private static class MethodNotAllowed extends SelfAuth2Exception { public MethodNotAllowed(String msg, Throwable t) { super(msg, t); } @Override public String getOAuth2ErrorCode() { return "method_not_allowed"; } @Override public int getHttpErrorCode() { return 405; } } }