今天研究了如何重写 MyBatis-Plus 的“自定义租户拦截器”。用过的人都知道,官方提供的自定义入口(TenantHandler#doTableFilter)看起来只能按表名决定是否拦截,似乎无法扩展更多自定义逻辑。
- /**
- * @author Lux Sun
- * @date 2023/7/18
- */
- @Component
- public class MyTenantHandler implements TenantHandler {
-
- @Override
- public Expression getTenantId(boolean select) {
- TenantInfo tenantInfo = TenantContext.get();
- if (tenantInfo == null) {
- return null;
- }
- String tenantId = tenantInfo.getId();
- return new StringValue(tenantId);
- }
-
- @Override
- public String getTenantIdColumn() {
- return "tenant_id";
- }
-
- @Override
- public boolean doTableFilter(String tableName) {
- if (StrUtil.equalsAny(tableName, "t_product")) {
- return true;
- }
- return false;
- }
- }
于是在网上找到有人通过重写 MyBatis-Plus 租户 SQL 解析器 TenantSqlParser 来实现的方案,需要重写的类代码如下:
- /*
- * Copyright (c) 2011-2020, baomidou (jobob@qq.com).
- *
- * Licensed under the Apache License, Version 2.0 (the "License"); you may not
- * use this file except in compliance with the License. You may obtain a copy of
- * the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations under
- * the License.
- */
- package com.baomidou.mybatisplus.extension.plugins.tenant;
-
- import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
- import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
- import com.baomidou.mybatisplus.core.toolkit.StringPool;
- import lombok.AllArgsConstructor;
- import lombok.Data;
- import lombok.EqualsAndHashCode;
- import lombok.NoArgsConstructor;
- import lombok.experimental.Accessors;
- import net.sf.jsqlparser.expression.BinaryExpression;
- import net.sf.jsqlparser.expression.Expression;
- import net.sf.jsqlparser.expression.Parenthesis;
- import net.sf.jsqlparser.expression.ValueListExpression;
- import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
- import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
- import net.sf.jsqlparser.expression.operators.relational.*;
- import net.sf.jsqlparser.schema.Column;
- import net.sf.jsqlparser.schema.Table;
- import net.sf.jsqlparser.statement.delete.Delete;
- import net.sf.jsqlparser.statement.insert.Insert;
- import net.sf.jsqlparser.statement.select.*;
- import net.sf.jsqlparser.statement.update.Update;
-
- import java.util.List;
-
- /**
- * 租户 SQL 解析器( TenantId 行级 )
- *
- * @author hubin
- * @since 2017-09-01
- */
- @Data
- @NoArgsConstructor
- @AllArgsConstructor
- @Accessors(chain = true)
- @EqualsAndHashCode(callSuper = true)
- public class TenantSqlParser extends AbstractJsqlParser {
-
- private TenantHandler tenantHandler;
-
- /**
- * select 语句处理
- */
- @Override
- public void processSelectBody(SelectBody selectBody) {
- if (selectBody instanceof PlainSelect) {
- processPlainSelect((PlainSelect) selectBody);
- } else if (selectBody instanceof WithItem) {
- WithItem withItem = (WithItem) selectBody;
- if (withItem.getSelectBody() != null) {
- processSelectBody(withItem.getSelectBody());
- }
- } else {
- SetOperationList operationList = (SetOperationList) selectBody;
- if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
- operationList.getSelects().forEach(this::processSelectBody);
- }
- }
- }
-
- /**
- * insert 语句处理
- */
- @Override
- public void processInsert(Insert insert) {
- if (tenantHandler.doTableFilter(insert.getTable().getName())) {
- // 过滤退出执行
- return;
- }
- insert.getColumns().add(new Column(tenantHandler.getTenantIdColumn()));
- if (insert.getSelect() != null) {
- processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
- } else if (insert.getItemsList() != null) {
- // fixed github pull/295
- ItemsList itemsList = insert.getItemsList();
- if (itemsList instanceof MultiExpressionList) {
- ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
- } else {
- ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
- }
- } else {
- throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
- }
- }
-
- /**
- * update 语句处理
- */
- @Override
- public void processUpdate(Update update) {
- final Table table = update.getTable();
- if (tenantHandler.doTableFilter(table.getName())) {
- // 过滤退出执行
- return;
- }
- update.setWhere(this.andExpression(table, update.getWhere()));
- }
-
- /**
- * delete 语句处理
- */
- @Override
- public void processDelete(Delete delete) {
- if (tenantHandler.doTableFilter(delete.getTable().getName())) {
- // 过滤退出执行
- return;
- }
- delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
- }
-
- /**
- * delete update 语句 where 处理
- */
- protected BinaryExpression andExpression(Table table, Expression where) {
- //获得where条件表达式
- EqualsTo equalsTo = new EqualsTo();
- equalsTo.setLeftExpression(this.getAliasColumn(table));
- equalsTo.setRightExpression(tenantHandler.getTenantId(false));
- if (null != where) {
- if (where instanceof OrExpression) {
- return new AndExpression(equalsTo, new Parenthesis(where));
- } else {
- return new AndExpression(equalsTo, where);
- }
- }
- return equalsTo;
- }
-
- /**
- * 处理 PlainSelect
- */
- protected void processPlainSelect(PlainSelect plainSelect) {
- processPlainSelect(plainSelect, false);
- }
-
- /**
- * 处理 PlainSelect
- *
- * @param plainSelect ignore
- * @param addColumn 是否添加租户列,insert into select语句中需要
- */
- protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
- FromItem fromItem = plainSelect.getFromItem();
- if (fromItem instanceof Table) {
- Table fromTable = (Table) fromItem;
- if (!tenantHandler.doTableFilter(fromTable.getName())) {
- //#1186 github
- plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
- if (addColumn) {
- plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
- }
- }
- } else {
- processFromItem(fromItem);
- }
- List
joins = plainSelect.getJoins(); - if (joins != null && joins.size() > 0) {
- joins.forEach(j -> {
- processJoin(j);
- processFromItem(j.getRightItem());
- });
- }
- }
-
- /**
- * 处理子查询等
- */
- protected void processFromItem(FromItem fromItem) {
- if (fromItem instanceof SubJoin) {
- SubJoin subJoin = (SubJoin) fromItem;
- if (subJoin.getJoinList() != null) {
- subJoin.getJoinList().forEach(this::processJoin);
- }
- if (subJoin.getLeft() != null) {
- processFromItem(subJoin.getLeft());
- }
- } else if (fromItem instanceof SubSelect) {
- SubSelect subSelect = (SubSelect) fromItem;
- if (subSelect.getSelectBody() != null) {
- processSelectBody(subSelect.getSelectBody());
- }
- } else if (fromItem instanceof ValuesList) {
- logger.debug("Perform a subquery, if you do not give us feedback");
- } else if (fromItem instanceof LateralSubSelect) {
- LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
- if (lateralSubSelect.getSubSelect() != null) {
- SubSelect subSelect = lateralSubSelect.getSubSelect();
- if (subSelect.getSelectBody() != null) {
- processSelectBody(subSelect.getSelectBody());
- }
- }
- }
- }
-
- /**
- * 处理联接语句
- */
- protected void processJoin(Join join) {
- if (join.getRightItem() instanceof Table) {
- Table fromTable = (Table) join.getRightItem();
- if (this.tenantHandler.doTableFilter(fromTable.getName())) {
- // 过滤退出执行
- return;
- }
- join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
- }
- }
-
- /**
- * 处理条件:
- * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
- * 默认tenantId的表达式: LongValue(1)这种依旧支持
- */
- protected Expression builderExpression(Expression currentExpression, Table table) {
- final Expression tenantExpression = tenantHandler.getTenantId(true);
- Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
- if (currentExpression == null) {
- return appendExpression;
- }
- if (currentExpression instanceof BinaryExpression) {
- BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
- doExpression(binaryExpression.getLeftExpression());
- doExpression(binaryExpression.getRightExpression());
- } else if (currentExpression instanceof InExpression) {
- InExpression inExp = (InExpression) currentExpression;
- ItemsList rightItems = inExp.getRightItemsList();
- if (rightItems instanceof SubSelect) {
- processSelectBody(((SubSelect) rightItems).getSelectBody());
- }
- }
- if (currentExpression instanceof OrExpression) {
- return new AndExpression(new Parenthesis(currentExpression), appendExpression);
- } else {
- return new AndExpression(currentExpression, appendExpression);
- }
- }
-
- protected void doExpression(Expression expression) {
- if (expression instanceof FromItem) {
- processFromItem((FromItem) expression);
- } else if (expression instanceof InExpression) {
- InExpression inExp = (InExpression) expression;
- ItemsList rightItems = inExp.getRightItemsList();
- if (rightItems instanceof SubSelect) {
- processSelectBody(((SubSelect) rightItems).getSelectBody());
- }
- }
- }
-
- /**
- * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
- * select a.id, b.name
- * from a
- * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
- *
- * @param expression
- * @param table
- * @return 加上别名的多租户字段表达式
- */
- protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
- Expression target;
- if (expression instanceof ValueListExpression) {
- InExpression inExpression = new InExpression();
- inExpression.setLeftExpression(this.getAliasColumn(table));
- inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
- target = inExpression;
- } else {
- EqualsTo equalsTo = new EqualsTo();
- equalsTo.setLeftExpression(this.getAliasColumn(table));
- equalsTo.setRightExpression(expression);
- target = equalsTo;
- }
- return target;
- }
-
- /**
- * 租户字段别名设置
- *
tenantId 或 tableAlias.tenantId
- *
- * @param table 表对象
- * @return 字段
- */
- protected Column getAliasColumn(Table table) {
- StringBuilder column = new StringBuilder();
- if (table.getAlias() != null) {
- column.append(table.getAlias().getName()).append(StringPool.DOT);
- }
- column.append(tenantHandler.getTenantIdColumn());
- return new Column(column.toString());
- }
- }
结果一顿操作猛如虎,最后却在 withItem.getSelectBody() 等处出现各种编译错误(推测是项目中 JSqlParser / MyBatis-Plus 的版本与这份源码不一致)……于是就放弃了。
后来回过头再看官方提供的自定义类才发现,doTableFilter 最终只是靠返回 true 或 false 来控制是否跳过租户拦截;一开始误以为只能在里面判断表名,被蒙蔽了~
- /**
- * @author Lux Sun
- * @date 2023/7/18
- */
- @Component
- public class MyTenantHandler implements TenantHandler {
-
- @Override
- public Expression getTenantId(boolean select) {
- TenantInfo tenantInfo = TenantContext.get();
- if (tenantInfo == null) {
- return null;
- }
- String tenantId = tenantInfo.getId();
- return new StringValue(tenantId);
- }
-
- @Override
- public String getTenantIdColumn() {
- return "tenant_id";
- }
-
- @Override
- public boolean doTableFilter(String tableName) {
- Boolean filter = SqlParserContext.get();
- if (filter) {
- return true;
- }
-
- if (StrUtil.equalsAny(tableName, "t_product")) {
- return true;
- }
- return false;
- }
- }
/**
 * Per-thread switch that lets the current thread bypass tenant SQL filtering.
 * <p>
 * Request threads are usually pooled and reused. Whoever calls {@link #set(Boolean)} must call
 * {@link #clear()} (typically in a {@code finally} block), or use
 * {@link #callWithoutTenant(java.util.function.Supplier)}. Otherwise the flag leaks into
 * unrelated later work on the same thread.
 *
 * @author Lux Sun
 * @date 2023/7/18
 */
public class SqlParserContext {

    private static final ThreadLocal<Boolean> CONTEXT = new ThreadLocal<>();

    /**
     * Sets the flag for the current thread.
     *
     * @param filter {@code true} to skip tenant filtering; {@code null} removes the entry
     */
    public static void set(Boolean filter) {
        if (filter == null) {
            CONTEXT.remove();
        } else {
            CONTEXT.set(filter);
        }
    }

    /**
     * Reads the flag for the current thread without creating a thread-local entry.
     *
     * @return {@code Boolean.TRUE} only when the flag was explicitly set to true, otherwise
     *         {@code Boolean.FALSE} (never {@code null})
     */
    public static Boolean get() {
        return Boolean.TRUE.equals(CONTEXT.get());
    }

    /**
     * Removes the flag for the current thread.
     */
    public static void clear() {
        CONTEXT.remove();
    }

    /**
     * Runs {@code action} with tenant filtering skipped. Afterwards the flag is restored to its
     * previous value, or removed if it was unset, even when the action throws.
     *
     * @param action work to run without tenant filtering
     * @param <T>    result type
     * @return the action's result
     */
    public static <T> T callWithoutTenant(java.util.function.Supplier<T> action) {
        Boolean previous = CONTEXT.get();
        CONTEXT.set(Boolean.TRUE);
        try {
            return action.get();
        } finally {
            if (previous == null) {
                CONTEXT.remove();
            } else {
                CONTEXT.set(previous);
            }
        }
    }
}
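我为它自定义了一个基于 ThreadLocal 的上下文类,这样可以按线程(也就是按本次请求)控制拦截行为:只要设置为 true 就会跳过租户拦截;否则启用拦截,但如果表名命中了过滤列表,仍然会跳过拦截。注意:Web 容器的线程是复用的,设置为 true 后务必在请求结束前调用 SqlParserContext.clear()(例如放在 finally 中),否则该标记会泄漏到后续复用同一线程的请求里,导致其它请求也绕过租户隔离。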
看一个案例:在需要跳过租户拦截的地方,把上下文设置为 true 即可(用完后记得清理)。
- /**
- * 获取财务记录
- * @param id
- * @return
- */
- @GetMapping("/finance/{id}")
- public ResultVO
getFinance(@PathVariable String id) { - SqlParserContext.set(true);
- FinancePO fice = ficeService.getById(id);
- return ResultVoUtil.buildSuccess(fice);
- }
最后,别忘了在 MyBatisConfig 配置类中把租户 SQL 解析器注册到分页拦截器上,注册后即可使用啦~
- /**
- * @author Lux Sun
- * @date 2023/7/18
- */
- @EnableTransactionManagement
- @Configuration
- public class MyBatisConfig {
-
- @Resource
- private MyTenantHandler myTenantHandler;
-
- @Bean
- public PaginationInterceptor paginationInterceptor() {
- PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
- paginationInterceptor.setSqlParserList(Collections.singletonList(new TenantSqlParser().setTenantHandler(myTenantHandler)));
- return paginationInterceptor;
- }
- }