新增mybatis语法的<>自动转义,优化mybatis语法和?{}不兼容的问题

This commit is contained in:
mxd
2021-12-30 20:45:41 +08:00
parent 1b91c08ee8
commit 5bda73696b
8 changed files with 91 additions and 112 deletions

View File

@@ -5,18 +5,12 @@ import org.ssssssss.magicapi.interceptor.SQLInterceptor;
import org.ssssssss.magicapi.model.RequestEntity;
import org.ssssssss.magicapi.modules.mybatis.MybatisParser;
import org.ssssssss.magicapi.modules.mybatis.SqlNode;
import org.ssssssss.script.MagicScriptContext;
import org.ssssssss.script.functions.StreamExtension;
import org.ssssssss.script.parsing.GenericTokenParser;
import org.ssssssss.script.parsing.ast.literal.BooleanLiteral;
import org.ssssssss.magicapi.modules.mybatis.TextSqlNode;
import org.ssssssss.script.runtime.RuntimeContext;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* SQL参数处理
@@ -25,14 +19,6 @@ import java.util.stream.IntStream;
*/
public class BoundSql {
private static final GenericTokenParser CONCAT_TOKEN_PARSER = new GenericTokenParser("${", "}", false);
private static final GenericTokenParser REPLACE_TOKEN_PARSER = new GenericTokenParser("#{", "}", true);
private static final GenericTokenParser IF_TOKEN_PARSER = new GenericTokenParser("?{", "}", true);
private static final GenericTokenParser IF_PARAM_TOKEN_PARSER = new GenericTokenParser("?{", ",", true);
private static final Pattern REPLACE_MULTI_WHITE_LINE = Pattern.compile("(\r?\n(\\s*\r?\n)+)");
private static final List<String> MYBATIS_TAGS = Arrays.asList("</where>", "</if>", "</trim>", "</set>", "</foreach>");
@@ -91,40 +77,12 @@ public class BoundSql {
this.sqlOrXml = sqlNode.getSql(varMap);
this.parameters = sqlNode.getParameters();
} else {
normal(runtimeContext, varMap);
normal(varMap);
}
}
private void normal(RuntimeContext runtimeContext, Map<String, Object> varMap) {
MagicScriptContext context = runtimeContext.getScriptContext();
// 处理?{}参数
this.sqlOrXml = IF_TOKEN_PARSER.parse(this.sqlOrXml.trim(), text -> {
AtomicBoolean ifTrue = new AtomicBoolean(false);
String val = IF_PARAM_TOKEN_PARSER.parse("?{" + text, param -> {
ifTrue.set(BooleanLiteral.isTrue(context.eval(param, varMap)));
return null;
});
return ifTrue.get() ? val : "";
});
// 处理${}参数
this.sqlOrXml = CONCAT_TOKEN_PARSER.parse(this.sqlOrXml, text -> String.valueOf(context.eval(text, varMap)));
// 处理#{}参数
this.sqlOrXml = REPLACE_TOKEN_PARSER.parse(this.sqlOrXml, text -> {
Object value = context.eval(text, varMap);
if (value == null) {
parameters.add(null);
return "?";
}
try {
//对集合自动展开
List<Object> objects = StreamExtension.arrayLikeToList(value);
parameters.addAll(objects);
return IntStream.range(0, objects.size()).mapToObj(t -> "?").collect(Collectors.joining(","));
} catch (Exception e) {
parameters.add(value);
return "?";
}
});
private void normal(Map<String, Object> varMap) {
this.sqlOrXml = TextSqlNode.parseSql(this.sqlOrXml, varMap, parameters);
this.sqlOrXml = this.sqlOrXml == null ? null : REPLACE_MULTI_WHITE_LINE.matcher(this.sqlOrXml.trim()).replaceAll("\r\n");
}

View File

@@ -64,8 +64,6 @@ public class ForeachSqlNode extends SqlNode {
if (value == null) {
return "";
}
// 开始拼接SQL,
String sql = StringUtils.defaultString(this.open);
// 如果集合是Collection对象或其子类则转成数组
if (value instanceof Collection) {
value = ((Collection) value).toArray();
@@ -74,19 +72,22 @@ public class ForeachSqlNode extends SqlNode {
if (!value.getClass().isArray()) {
return "";
}
// 开始拼接SQL,
StringBuilder sqlBuilder = new StringBuilder(StringUtils.defaultString(this.open));
// 获取数组长度
int len = Array.getLength(value);
for (int i = 0; i < len; i++) {
// 存入item对象
paramMap.put(this.item, Array.get(value, i));
// 拼接子节点
sql += executeChildren(paramMap, parameters);
sqlBuilder.append(executeChildren(paramMap, parameters));
// 拼接分隔符
if (i + 1 < len) {
sql += StringUtils.defaultString(this.separator);
sqlBuilder.append(StringUtils.defaultString(this.separator));
}
}
// 拼接结束SQL
return sql + StringUtils.defaultString(this.close);
sqlBuilder.append(StringUtils.defaultString(this.close));
return sqlBuilder.toString();
}
}
}

View File

@@ -16,7 +16,7 @@ public class IfSqlNode extends SqlNode {
/**
* 判断表达式
*/
private String test;
private final String test;
public IfSqlNode(String test) {
this.test = test;
@@ -32,4 +32,4 @@ public class IfSqlNode extends SqlNode {
}
return "";
}
}
}

View File

@@ -8,12 +8,21 @@ import org.w3c.dom.NodeList;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import java.io.ByteArrayInputStream;
import java.util.regex.Pattern;
public class MybatisParser {
private static final Pattern ESCAPE_LT_PATTERN = Pattern.compile("<([\\d'\"\\s=>#$?(])");
private static final Pattern ESCAPE_GT_PATTERN = Pattern.compile("([})\\s<\\d])>");
private static final String ESCAPE_LT_REPLACEMENT = "&lt;$1";
private static final String ESCAPE_GT_REPLACEMENT = "$1&gt;";
public static SqlNode parse(String xml) {
try {
xml = "<mybatis>" + xml + "</mybatis>";
xml = "<magic-api>" + escapeXml(xml) + "</magic-api>";
DocumentBuilder documentBuilder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
Document document = documentBuilder.parse(new ByteArrayInputStream(xml.getBytes()));
SqlNode sqlNode = new TextSqlNode("");
@@ -24,6 +33,10 @@ public class MybatisParser {
}
}
private static String escapeXml(String xml) {
return ESCAPE_GT_PATTERN.matcher(ESCAPE_LT_PATTERN.matcher(xml).replaceAll(ESCAPE_LT_REPLACEMENT)).replaceAll(ESCAPE_GT_REPLACEMENT);
}
private static void parseNodeList(SqlNode sqlNode, NodeList nodeList) {
for (int i = 0, len = nodeList.getLength(); i < len; i++) {
Node node = nodeList.item(i);
@@ -82,20 +95,22 @@ public class MybatisParser {
* 解析set节点
*/
private static SetSqlNode parseSetSqlNode() {
SetSqlNode setSqlNode = new SetSqlNode();
return setSqlNode;
return new SetSqlNode();
}
/**
* 解析where节点
*/
private static WhereSqlNode parseWhereSqlNode() {
WhereSqlNode whereSqlNode = new WhereSqlNode();
return whereSqlNode;
return new WhereSqlNode();
}
private static String getNodeAttributeValue(Node node, String attributeKey) {
Node item = node.getAttributes().getNamedItem(attributeKey);
return item != null ? item.getNodeValue() : null;
}
public static void main(String[] args) {
System.out.println(escapeXml("<where> <if test=\"111\"> and 1 < 2 and 1<6 and 2>#{666}</if></where>"));
}
}

View File

@@ -5,8 +5,6 @@ import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* sql节点
@@ -15,14 +13,7 @@ import java.util.regex.Pattern;
* @version : 2020-05-18
*/
public abstract class SqlNode {
/**
* 提取#{}的正则
*/
final Pattern expressionRegx = Pattern.compile("#\\{(.*?)\\}");
/**
* 提取${}的正则
*/
final Pattern replaceRegx = Pattern.compile("\\$\\{(.*?)\\}");
/**
* 子节点
*/
@@ -56,30 +47,15 @@ public abstract class SqlNode {
* 获取子节点SQL
*/
public String executeChildren(Map<String, Object> paramMap, List<Object> parameters) {
String sql = "";
StringBuilder sqlBuilder = new StringBuilder();
for (SqlNode node : nodes) {
sql += StringUtils.defaultString(node.getSql(paramMap, parameters)) + " ";
sqlBuilder.append(StringUtils.defaultString(node.getSql(paramMap, parameters)));
sqlBuilder.append(" ");
}
return sql;
return sqlBuilder.toString();
}
public List<Object> getParameters() {
return parameters;
}
/**
* 根据正则表达式提取参数
*
* @param pattern 正则表达式
* @param sql SQL
*/
public List<String> extractParameter(Pattern pattern, String sql) {
Matcher matcher = pattern.matcher(sql);
List<String> results = new ArrayList<>();
while (matcher.find()) {
results.add(matcher.group(1));
}
return results;
}
}
}

View File

@@ -1,11 +1,15 @@
package org.ssssssss.magicapi.modules.mybatis;
import org.apache.commons.lang3.StringUtils;
import org.ssssssss.magicapi.script.ScriptManager;
import org.ssssssss.script.functions.StreamExtension;
import org.ssssssss.script.parsing.GenericTokenParser;
import org.ssssssss.script.parsing.ast.literal.BooleanLiteral;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
* 普通SQL节点
@@ -14,33 +18,58 @@ import java.util.Objects;
* @version : 2020-05-18
*/
public class TextSqlNode extends SqlNode {
private static final GenericTokenParser CONCAT_TOKEN_PARSER = new GenericTokenParser("${", "}", false);
private static final GenericTokenParser REPLACE_TOKEN_PARSER = new GenericTokenParser("#{", "}", true);
private static final GenericTokenParser IF_TOKEN_PARSER = new GenericTokenParser("?{", "}", true);
private static final GenericTokenParser IF_PARAM_TOKEN_PARSER = new GenericTokenParser("?{", ",", true);
/**
* SQL
*/
private String text;
private final String text;
public TextSqlNode(String text) {
this.text = text;
}
public static String parseSql(String sql, Map<String, Object> varMap, List<Object> parameters) {
// 处理?{}参数
sql = IF_TOKEN_PARSER.parse(sql.trim(), text -> {
AtomicBoolean ifTrue = new AtomicBoolean(false);
String val = IF_PARAM_TOKEN_PARSER.parse("?{" + text, param -> {
ifTrue.set(BooleanLiteral.isTrue(ScriptManager.executeExpression(param, varMap)));
return null;
});
return ifTrue.get() ? val : "";
});
// 处理${}参数
sql = CONCAT_TOKEN_PARSER.parse(sql, text -> String.valueOf(ScriptManager.executeExpression(text, varMap)));
// 处理#{}参数
sql = REPLACE_TOKEN_PARSER.parse(sql, text -> {
Object value = ScriptManager.executeExpression(text, varMap);
if (value == null) {
parameters.add(null);
return "?";
}
try {
//对集合自动展开
List<Object> objects = StreamExtension.arrayLikeToList(value);
parameters.addAll(objects);
return IntStream.range(0, objects.size()).mapToObj(t -> "?").collect(Collectors.joining(","));
} catch (Exception e) {
parameters.add(value);
return "?";
}
});
return sql;
}
@Override
public String getSql(Map<String, Object> paramMap, List<Object> parameters) {
String sql = text;
if (StringUtils.isNotBlank(text)) {
// 提取#{}表达式
List<String> expressions = extractParameter(expressionRegx, text);
for (String expression : expressions) {
// 执行表达式
Object val = ScriptManager.executeExpression(expression, paramMap);
parameters.add(val);
sql = sql.replaceFirst(expressionRegx.pattern(), "?");
}
expressions = extractParameter(replaceRegx, text);
for (String expression : expressions) {
Object val = ScriptManager.executeExpression(expression, paramMap);
sql = sql.replaceFirst(replaceRegx.pattern(), Objects.toString(val, ""));
}
}
return sql + executeChildren(paramMap, parameters).trim();
return parseSql(text, paramMap, parameters) + executeChildren(paramMap, parameters).trim();
}
}
}

View File

@@ -86,4 +86,4 @@ public class TrimSqlNode extends SqlNode {
}
return sqlBuffer.toString();
}
}
}

View File

@@ -24,4 +24,4 @@ public class WhereSqlNode extends TrimSqlNode {
return sql;
}
}
}