解决 MyBatis 中使用 `${}` 进行权限注入导致的安全问题（文章来源：https://www.toymoban.com/news/detail-667098.html）
@Slf4j
@Component
public class MybatisSecureProcessor implements BeanPostProcessor {
//自定义的符号
private char customToken = '@';
//Mybatis的配置变量
private Properties mybatisVariables;
@Override
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
//拦截SqlSessionFactory
if (! (bean instanceof SqlSessionFactory)) {
return bean;
}
SqlSessionFactory sqlSessionFactory = (SqlSessionFactory) bean;
Configuration configuration = sqlSessionFactory.getConfiguration();
//获取所有sql片段,这里不能是Collection<MappedStatement>, 因为里面可能有Ambiguity
Collection<?> mappedStatements = configuration.getMappedStatements();
//获取所有配置的变量
Properties variables = configuration.getVariables();
if (variables != null && variables.size() != 0) {
this.mybatisVariables = variables;
}
//这里面会有Ambiguity
for (Object item : mappedStatements) {
if (! (item instanceof MappedStatement)) {
continue;
}
MappedStatement mappedStatement = (MappedStatement) item;
try {
SqlSource sqlSource = mappedStatement.getSqlSource();
//我们只处理动态sql
if (sqlSource instanceof DynamicSqlSource) {
handleDynamicSqlSource((DynamicSqlSource) sqlSource);
} else if (sqlSource instanceof RawSqlSource) {
handleRawSqlSource(mappedStatement,(RawSqlSource) sqlSource);
} else {
log.error("不支持的mapper" + sqlSource.getClass().getSimpleName() + " unhandled");
}
} catch (Exception e) {
e.printStackTrace();
}
}
return bean;
}
private void handleDynamicSqlSource(DynamicSqlSource sqlSource) {
try {
Field rootSqlNodeField = DynamicSqlSource.class.getDeclaredField("rootSqlNode");
rootSqlNodeField.setAccessible(true);
SqlNode rootSqlNode = (SqlNode) rootSqlNodeField.get(sqlSource);
iterateSqlNode(rootSqlNode, sqlSource, rootSqlNodeField);
} catch (Exception e) {
e.printStackTrace();
}
}
private void handleRawSqlSource(MappedStatement mappedStatement,RawSqlSource sqlSource) throws NoSuchFieldException, IllegalAccessException {
//有兴趣可以实现一下
//改造方法:1、获取 sqlsource StaticSqlSource;
Field sqlSource1 = sqlSource.getClass().getDeclaredField("sqlSource");
sqlSource1.setAccessible(true);
Object o = sqlSource1.get(sqlSource);
Field sql1 = o.getClass().getDeclaredField("sql");
sql1.setAccessible(true);
BoundSql boundSql = sqlSource.getBoundSql(mappedStatement);
String sql = boundSql.getSql();
if (sql.contains(customToken+"{")) {
String newSql = tryParseCustomToken(sql);
sql1.set(o, newSql);
//将存放此sqlNode的地方换成TextSqlNode
TextSqlNode textSqlNode = new TextSqlNode(newSql);
DynamicSqlSource dynamicSqlSource = new DynamicSqlSource(mappedStatement.getConfiguration(),textSqlNode);
Field sqlSourceAsp = mappedStatement.getClass().getDeclaredField("sqlSource");
sqlSourceAsp.setAccessible(true);
//替换数据源
sqlSourceAsp.set(mappedStatement,dynamicSqlSource);
}
}
*
* 遍历所有sqlNode。
* @param sqlNode 需要被遍历的sqlNode。
* @param target 需要被遍历的sqlNode所在的对象。
* @param field 需要被遍历的sqlNode所在的字段。
* @throws Exception
@SuppressWarnings("unchecked")
private void iterateSqlNode(SqlNode sqlNode, Object target, Field field) throws Exception {
if (sqlNode instanceof MixedSqlNode) {
Field contentsField = MixedSqlNode.class.getDeclaredField("contents");
contentsField.setAccessible(true);
List<SqlNode> contents = (List<SqlNode>) contentsField.get(sqlNode);
for (SqlNode n : contents) {
iterateSqlNode(n, sqlNode, contentsField);
}
} else if (sqlNode instanceof StaticTextSqlNode) {
Field textField = StaticTextSqlNode.class.getDeclaredField("text");
textField.setAccessible(true);
String text = (String) textField.get(sqlNode);
String afterParsed = tryParseCustomToken(text);
if (afterParsed == null) {
return;
}
//将存放此sqlNode的地方换成TextSqlNode
TextSqlNode textSqlNode = new TextSqlNode(afterParsed);
saveNewNode(textSqlNode, sqlNode, target, field);
} else if (sqlNode instanceof ForEachSqlNode) {
Field contentsField = ForEachSqlNode.class.getDeclaredField("contents");
contentsField.setAccessible(true);
SqlNode contents = (SqlNode) contentsField.get(sqlNode);
iterateSqlNode(contents, sqlNode, contentsField);
} else if (sqlNode instanceof IfSqlNode) {
Field contentsField = IfSqlNode.class.getDeclaredField("contents");
contentsField.setAccessible(true);
SqlNode contents = (SqlNode) contentsField.get(sqlNode);
iterateSqlNode(contents, sqlNode, contentsField);
}
//TODO ...处理其他你需要处理的类型
}
//寻找自定义占位符
private String tryParseCustomToken(String text) {
List<TokenRecord> records = new LinkedList<>();
int strLength = text.length();
findCustom: for (int i = 0; i < strLength; i++) {
if (text.charAt(i) == this.customToken && text.charAt(i + 1) == '{') {
for (int j = i + 2; j < strLength; j++) {
if (text.charAt(j) == '}') {
TokenRecord record = new TokenRecord(i, i + 1, j);
String placeholderName = text.substring(i + 2, j);
record.placeholderName = placeholderName;
//注意,Mybatis会在首次加载Mapper的时候,把配置变量中存在的占位符先替换掉,而不是等到SQL执行的时候再替换
//例如,配置文件中有mybatis.configuration.variables.abc=xxx
//那么,Mybatis初始化的时候会把Mapper中的所有${abc}替换为xxx
//我们使用了自定义占位符,所以要替Mybatis完成这一步
if (this.mybatisVariables != null) {
record.placeholderValue = this.mybatisVariables.getProperty(placeholderName);
}
records.add(record);
continue findCustom;
}
}
throw new IllegalStateException("token not match");
}
}
if (records.isEmpty()) {
return null;
}
//生成替换后的sql字符串
StringBuilder builder = new StringBuilder();
//原text中将要被拼接的字符索引
int appendIndex = 0;
for (TokenRecord tr : records) {
builder.append(text, appendIndex, tr.customTokenIndex);
if (tr.placeholderValue != null) {
builder.append(tr.placeholderValue);
} else {
builder.append('$').append('{');
builder.append(tr.placeholderName);
builder.append('}');
}
appendIndex = tr.endBracketIndex + 1;
}
builder.append(text, appendIndex, text.length() - 1);
return builder.toString();
}
//把替换后的sqlNode保存到对应字段
@SuppressWarnings("unchecked")
public void saveNewNode(TextSqlNode newNode, SqlNode oldNode, Object targetObject, Field field) throws Exception {
Class<?> fieldType = field.getType();
if (List.class.isAssignableFrom(fieldType)) {
List<Object> list = (List<Object>) field.get(targetObject);
for (int i = 0; i < list.size(); i++) {
if (list.get(i) == oldNode) {
list.set(i, newNode);
return;
}
}
} else if (SqlNode.class.isAssignableFrom(fieldType)) {
field.set(targetObject, newNode);
}
}
//记录占位符替换信息的结构体
private static class TokenRecord {
int customTokenIndex;
int startBracketIndex;
int endBracketIndex;
String placeholderName;
String placeholderValue;
TokenRecord(int customTokenIndex, int startBracketIndex, int endBracketIndex) {
this.customTokenIndex = customTokenIndex;
this.startBracketIndex = startBracketIndex;
this.endBracketIndex = endBracketIndex;
}
}
}
文章来源地址：https://www.toymoban.com/news/detail-667098.html
到这里，关于 MyBatis 中使用 `$` 代码逃避漏洞扫描的文章就介绍完了。注意：`@{}` 在启动时会被改写为 `${}`，二者在运行时语义完全相同；这种做法只是让静态扫描工具识别不到 `${}`，并不能消除 SQL 注入风险，传入的值仍需做白名单校验。如果您还想了解更多内容，请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章，希望大家以后多多支持TOY模板网！