• MyBatis-Plus - 自定义租户拦截器,一定要这样吗?!


    今天看到有人在讨论如何重写“自定义租户拦截器”。用过的人都知道,官方提供的自定义入口(doTableFilter)看起来只是根据表名决定是否拦截,似乎无法扩展更多自定义逻辑。

    1. /**
    2. * @author Lux Sun
    3. * @date 2023/7/18
    4. */
    5. @Component
    6. public class MyTenantHandler implements TenantHandler {
    7. @Override
    8. public Expression getTenantId(boolean select) {
    9. TenantInfo tenantInfo = TenantContext.get();
    10. if (tenantInfo == null) {
    11. return null;
    12. }
    13. String tenantId = tenantInfo.getId();
    14. return new StringValue(tenantId);
    15. }
    16. @Override
    17. public String getTenantIdColumn() {
    18. return "tenant_id";
    19. }
    20. @Override
    21. public boolean doTableFilter(String tableName) {
    22. if (StrUtil.equalsAny(tableName, "t_product")) {
    23. return true;
    24. }
    25. return false;
    26. }
    27. }

    于是,网上居然有人重写 MyBatis-Plus 的租户拦截器,代码如下(需要重写的类):

    1. /*
    2. * Copyright (c) 2011-2020, baomidou (jobob@qq.com).
    3. *

    4. * Licensed under the Apache License, Version 2.0 (the "License"); you may not
    5. * use this file except in compliance with the License. You may obtain a copy of
    6. * the License at
    7. *

    8. * https://www.apache.org/licenses/LICENSE-2.0
    9. *

    10. * Unless required by applicable law or agreed to in writing, software
    11. * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    12. * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    13. * License for the specific language governing permissions and limitations under
    14. * the License.
    15. */
    16. package com.baomidou.mybatisplus.extension.plugins.tenant;
    17. import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
    18. import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
    19. import com.baomidou.mybatisplus.core.toolkit.StringPool;
    20. import lombok.AllArgsConstructor;
    21. import lombok.Data;
    22. import lombok.EqualsAndHashCode;
    23. import lombok.NoArgsConstructor;
    24. import lombok.experimental.Accessors;
    25. import net.sf.jsqlparser.expression.BinaryExpression;
    26. import net.sf.jsqlparser.expression.Expression;
    27. import net.sf.jsqlparser.expression.Parenthesis;
    28. import net.sf.jsqlparser.expression.ValueListExpression;
    29. import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
    30. import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
    31. import net.sf.jsqlparser.expression.operators.relational.*;
    32. import net.sf.jsqlparser.schema.Column;
    33. import net.sf.jsqlparser.schema.Table;
    34. import net.sf.jsqlparser.statement.delete.Delete;
    35. import net.sf.jsqlparser.statement.insert.Insert;
    36. import net.sf.jsqlparser.statement.select.*;
    37. import net.sf.jsqlparser.statement.update.Update;
    38. import java.util.List;
    39. /**
    40. * 租户 SQL 解析器( TenantId 行级 )
    41. *
    42. * @author hubin
    43. * @since 2017-09-01
    44. */
    45. @Data
    46. @NoArgsConstructor
    47. @AllArgsConstructor
    48. @Accessors(chain = true)
    49. @EqualsAndHashCode(callSuper = true)
    50. public class TenantSqlParser extends AbstractJsqlParser {
    51. private TenantHandler tenantHandler;
    52. /**
    53. * select 语句处理
    54. */
    55. @Override
    56. public void processSelectBody(SelectBody selectBody) {
    57. if (selectBody instanceof PlainSelect) {
    58. processPlainSelect((PlainSelect) selectBody);
    59. } else if (selectBody instanceof WithItem) {
    60. WithItem withItem = (WithItem) selectBody;
    61. if (withItem.getSelectBody() != null) {
    62. processSelectBody(withItem.getSelectBody());
    63. }
    64. } else {
    65. SetOperationList operationList = (SetOperationList) selectBody;
    66. if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
    67. operationList.getSelects().forEach(this::processSelectBody);
    68. }
    69. }
    70. }
    71. /**
    72. * insert 语句处理
    73. */
    74. @Override
    75. public void processInsert(Insert insert) {
    76. if (tenantHandler.doTableFilter(insert.getTable().getName())) {
    77. // 过滤退出执行
    78. return;
    79. }
    80. insert.getColumns().add(new Column(tenantHandler.getTenantIdColumn()));
    81. if (insert.getSelect() != null) {
    82. processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
    83. } else if (insert.getItemsList() != null) {
    84. // fixed github pull/295
    85. ItemsList itemsList = insert.getItemsList();
    86. if (itemsList instanceof MultiExpressionList) {
    87. ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
    88. } else {
    89. ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
    90. }
    91. } else {
    92. throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
    93. }
    94. }
    95. /**
    96. * update 语句处理
    97. */
    98. @Override
    99. public void processUpdate(Update update) {
    100. final Table table = update.getTable();
    101. if (tenantHandler.doTableFilter(table.getName())) {
    102. // 过滤退出执行
    103. return;
    104. }
    105. update.setWhere(this.andExpression(table, update.getWhere()));
    106. }
    107. /**
    108. * delete 语句处理
    109. */
    110. @Override
    111. public void processDelete(Delete delete) {
    112. if (tenantHandler.doTableFilter(delete.getTable().getName())) {
    113. // 过滤退出执行
    114. return;
    115. }
    116. delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
    117. }
    118. /**
    119. * delete update 语句 where 处理
    120. */
    121. protected BinaryExpression andExpression(Table table, Expression where) {
    122. //获得where条件表达式
    123. EqualsTo equalsTo = new EqualsTo();
    124. equalsTo.setLeftExpression(this.getAliasColumn(table));
    125. equalsTo.setRightExpression(tenantHandler.getTenantId(false));
    126. if (null != where) {
    127. if (where instanceof OrExpression) {
    128. return new AndExpression(equalsTo, new Parenthesis(where));
    129. } else {
    130. return new AndExpression(equalsTo, where);
    131. }
    132. }
    133. return equalsTo;
    134. }
    135. /**
    136. * 处理 PlainSelect
    137. */
    138. protected void processPlainSelect(PlainSelect plainSelect) {
    139. processPlainSelect(plainSelect, false);
    140. }
    141. /**
    142. * 处理 PlainSelect
    143. *
    144. * @param plainSelect ignore
    145. * @param addColumn 是否添加租户列,insert into select语句中需要
    146. */
    147. protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
    148. FromItem fromItem = plainSelect.getFromItem();
    149. if (fromItem instanceof Table) {
    150. Table fromTable = (Table) fromItem;
    151. if (!tenantHandler.doTableFilter(fromTable.getName())) {
    152. //#1186 github
    153. plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
    154. if (addColumn) {
    155. plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
    156. }
    157. }
    158. } else {
    159. processFromItem(fromItem);
    160. }
    161. List joins = plainSelect.getJoins();
    162. if (joins != null && joins.size() > 0) {
    163. joins.forEach(j -> {
    164. processJoin(j);
    165. processFromItem(j.getRightItem());
    166. });
    167. }
    168. }
    169. /**
    170. * 处理子查询等
    171. */
    172. protected void processFromItem(FromItem fromItem) {
    173. if (fromItem instanceof SubJoin) {
    174. SubJoin subJoin = (SubJoin) fromItem;
    175. if (subJoin.getJoinList() != null) {
    176. subJoin.getJoinList().forEach(this::processJoin);
    177. }
    178. if (subJoin.getLeft() != null) {
    179. processFromItem(subJoin.getLeft());
    180. }
    181. } else if (fromItem instanceof SubSelect) {
    182. SubSelect subSelect = (SubSelect) fromItem;
    183. if (subSelect.getSelectBody() != null) {
    184. processSelectBody(subSelect.getSelectBody());
    185. }
    186. } else if (fromItem instanceof ValuesList) {
    187. logger.debug("Perform a subquery, if you do not give us feedback");
    188. } else if (fromItem instanceof LateralSubSelect) {
    189. LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
    190. if (lateralSubSelect.getSubSelect() != null) {
    191. SubSelect subSelect = lateralSubSelect.getSubSelect();
    192. if (subSelect.getSelectBody() != null) {
    193. processSelectBody(subSelect.getSelectBody());
    194. }
    195. }
    196. }
    197. }
    198. /**
    199. * 处理联接语句
    200. */
    201. protected void processJoin(Join join) {
    202. if (join.getRightItem() instanceof Table) {
    203. Table fromTable = (Table) join.getRightItem();
    204. if (this.tenantHandler.doTableFilter(fromTable.getName())) {
    205. // 过滤退出执行
    206. return;
    207. }
    208. join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
    209. }
    210. }
    211. /**
    212. * 处理条件:
    213. * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
    214. * 默认tenantId的表达式: LongValue(1)这种依旧支持
    215. */
    216. protected Expression builderExpression(Expression currentExpression, Table table) {
    217. final Expression tenantExpression = tenantHandler.getTenantId(true);
    218. Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
    219. if (currentExpression == null) {
    220. return appendExpression;
    221. }
    222. if (currentExpression instanceof BinaryExpression) {
    223. BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
    224. doExpression(binaryExpression.getLeftExpression());
    225. doExpression(binaryExpression.getRightExpression());
    226. } else if (currentExpression instanceof InExpression) {
    227. InExpression inExp = (InExpression) currentExpression;
    228. ItemsList rightItems = inExp.getRightItemsList();
    229. if (rightItems instanceof SubSelect) {
    230. processSelectBody(((SubSelect) rightItems).getSelectBody());
    231. }
    232. }
    233. if (currentExpression instanceof OrExpression) {
    234. return new AndExpression(new Parenthesis(currentExpression), appendExpression);
    235. } else {
    236. return new AndExpression(currentExpression, appendExpression);
    237. }
    238. }
    239. protected void doExpression(Expression expression) {
    240. if (expression instanceof FromItem) {
    241. processFromItem((FromItem) expression);
    242. } else if (expression instanceof InExpression) {
    243. InExpression inExp = (InExpression) expression;
    244. ItemsList rightItems = inExp.getRightItemsList();
    245. if (rightItems instanceof SubSelect) {
    246. processSelectBody(((SubSelect) rightItems).getSelectBody());
    247. }
    248. }
    249. }
    250. /**
    251. * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
    252. * select a.id, b.name
    253. * from a
    254. * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
    255. *
    256. * @param expression
    257. * @param table
    258. * @return 加上别名的多租户字段表达式
    259. */
    260. protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
    261. Expression target;
    262. if (expression instanceof ValueListExpression) {
    263. InExpression inExpression = new InExpression();
    264. inExpression.setLeftExpression(this.getAliasColumn(table));
    265. inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
    266. target = inExpression;
    267. } else {
    268. EqualsTo equalsTo = new EqualsTo();
    269. equalsTo.setLeftExpression(this.getAliasColumn(table));
    270. equalsTo.setRightExpression(expression);
    271. target = equalsTo;
    272. }
    273. return target;
    274. }
    275. /**
    276. * 租户字段别名设置
    277. *

      tenantId 或 tableAlias.tenantId

    278. *
    279. * @param table 表对象
    280. * @return 字段
    281. */
    282. protected Column getAliasColumn(Table table) {
    283. StringBuilder column = new StringBuilder();
    284. if (table.getAlias() != null) {
    285. column.append(table.getAlias().getName()).append(StringPool.DOT);
    286. }
    287. column.append(tenantHandler.getTenantIdColumn());
    288. return new Column(column.toString());
    289. }
    290. }

    发现一顿操作猛如虎下来还各种编译报错“withItem.getSelectBody()”……于是就放弃了。

    后来回过头去看之前那个官方提供的自定义类才发现,其实最终还是靠 doTableFilter 返回 true 或 false 来控制是否启用租户拦截功能,一开始是被“只能填写表名”这个想法给蒙蔽了~

    1. /**
    2. * @author Lux Sun
    3. * @date 2023/7/18
    4. */
    5. @Component
    6. public class MyTenantHandler implements TenantHandler {
    7. @Override
    8. public Expression getTenantId(boolean select) {
    9. TenantInfo tenantInfo = TenantContext.get();
    10. if (tenantInfo == null) {
    11. return null;
    12. }
    13. String tenantId = tenantInfo.getId();
    14. return new StringValue(tenantId);
    15. }
    16. @Override
    17. public String getTenantIdColumn() {
    18. return "tenant_id";
    19. }
    20. @Override
    21. public boolean doTableFilter(String tableName) {
    22. Boolean filter = SqlParserContext.get();
    23. if (filter) {
    24. return true;
    25. }
    26. if (StrUtil.equalsAny(tableName, "t_product")) {
    27. return true;
    28. }
    29. return false;
    30. }
    31. }
    /**
     * Per-thread boolean flag backed by a {@link ThreadLocal}.
     *
     * <p>The value is confined to the calling thread. Callers that set it should
     * call {@link #clear()} when done so the value does not outlive the unit of
     * work on a reused thread.
     *
     * @author Lux Sun
     * @date 2023/7/18
     */
    public class SqlParserContext {

        // Typed as ThreadLocal<Boolean>: with a raw ThreadLocal, get() yields
        // Object and cannot be returned as Boolean.
        private static final ThreadLocal<Boolean> CONTEXT = new ThreadLocal<>();

        /**
         * Stores the flag for the current thread.
         *
         * @param filter the flag value; {@code null} is treated as unset by {@link #get()}
         */
        public static void set(Boolean filter) {
            CONTEXT.set(filter);
        }

        /**
         * Reads the flag for the current thread.
         *
         * @return the stored flag; when nothing (or {@code null}) is stored,
         *         {@code false} is stored and returned, so the result is never {@code null}
         */
        public static Boolean get() {
            Boolean filter = CONTEXT.get();
            if (filter == null) {
                filter = Boolean.FALSE;
                CONTEXT.set(filter);
            }
            return filter;
        }

        /**
         * Removes the flag for the current thread.
         */
        public static void clear() {
            CONTEXT.remove();
        }
    }

    我给它自定义了一个基于 ThreadLocal 的上下文类,这样每个线程都可以针对本次请求单独设置:只要设置为 true 就会跳过租户拦截;否则启用租户拦截,除非表名命中了忽略列表(例如 t_product),此时同样会跳过拦截。

    看一个案例:在需要跳过租户拦截的地方,把上下文设置为 true 即可(请求线程通常会被线程池复用,用完后建议调用 SqlParserContext.clear() 清理,避免影响后续请求)。

    1. /**
    2. * 获取财务记录
    3. * @param id
    4. * @return
    5. */
    6. @GetMapping("/finance/{id}")
    7. public ResultVO getFinance(@PathVariable String id) {
    8. SqlParserContext.set(true);
    9. FinancePO fice = ficeService.getById(id);
    10. return ResultVoUtil.buildSuccess(fice);
    11. }

    最后,别忘了还需要注册 MyBatisConfig 配置类,注册之后即可使用啦~

    1. /**
    2. * @author Lux Sun
    3. * @date 2023/7/18
    4. */
    5. @EnableTransactionManagement
    6. @Configuration
    7. public class MyBatisConfig {
    8. @Resource
    9. private MyTenantHandler myTenantHandler;
    10. @Bean
    11. public PaginationInterceptor paginationInterceptor() {
    12. PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
    13. paginationInterceptor.setSqlParserList(Collections.singletonList(new TenantSqlParser().setTenantHandler(myTenantHandler)));
    14. return paginationInterceptor;
    15. }
    16. }
  • 相关阅读:
    WebSocket--1.协议解析
    Gradle系列——Gradle文件操作,Gradle依赖(基于Gradle文档7.5)day3-1
    外汇天眼:外汇交易商常见黑心手法大公开!投资务必留意这5种骗局
    16.cuBLAS开发指南中文版--cuBLAS中的Level-1函数rotm()和rotmg()
    一键将Web页面保存至Anki
    unity PostProcess 屏幕后处理
    【Linux虚拟机安装】在VMware Workstation上安装ubuntu虚拟机
    git能pink成功,为什么一直克隆超时啊
    Acwing第 67 场周赛
    Codeforces Round #832 (Div. 2)——A、B、C、D
  • 原文地址:https://blog.csdn.net/Dream_Weave/article/details/133913875