JPA分页踩坑指南
1. 原生 SQL 查询返回的 VO 类如果包含主键 id,则无法自动映射,需要借助投影(ResultTransformer)手动映射,所以我定义了一个投影工具类 JpaCommonService。
2. 异步调用原生查询方法时,需要 unwrap 为接口类型:
NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);
而不能 unwrap 为实现类:NativeQueryImpl<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQueryImpl.class);
3. 使用 JPQL(entityManager.createQuery())做分页查询时,如果分页需要用到 FROM 子查询,是不支持的,例如不支持 select count(1) from (select * from User) t;这种情况只能改用原生 SQL。
以下是分页返回vo的具体实现
分页返回vo
下面是用到的一些类和方法
分页参数
package com.example.springbootjpadruid.domain.common;import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class BasePage {

    // NOTE(review): this class is annotated with @Builder while the fields below rely on
    // plain initializers; Lombok's builder ignores such initializers unless the field is
    // marked @Builder.Default -- confirm whether builder-created instances should also
    // start at page 1 / size 10.

    /** Page number as supplied by the caller; values below 1 are treated as the first page. */
    private int page = 1;

    /** Number of rows per page. */
    private int size = 10;

    /** Whether paging should be applied to the query. */
    private boolean needPage = false;

    /**
     * Returns the page index shifted down by one.
     *
     * @return {@code 0} when the stored page is below 1, otherwise {@code page - 1}
     */
    public int getPage() {
        return page < 1 ? 0 : page - 1;
    }

    /**
     * Returns the row offset of the current page.
     *
     * @return the value of {@link #getPage()} multiplied by the page size
     */
    public int getOffset() {
        final int pageIndex = getPage();
        return pageIndex * size;
    }
}
分页返回结果
package com.example.springbootjpadruid.domain.common;import cn.hutool.core.bean.BeanUtil;
import lombok.Data;
import org.springframework.data.domain.Page;import java.util.List;@Data
public class PageVO<T> {

    /** Page number exposed to clients; populated as {@code Page.getNumber() + 1}. */
    private int page = 1;

    /** Page size. */
    private int size = 10;

    /** Total number of elements across all pages. */
    private long total;

    /** Total number of pages. */
    private int totalPage;

    /** Rows of the current page. */
    private List<T> list;

    /**
     * Builds a {@code PageVO} from a page of source objects, copying every element
     * to the target type with {@code BeanUtil.copyToList}.
     *
     * @param page  the source page
     * @param clazz the element type of the resulting list
     * @param <S>   source element type
     * @param <T>   target element type
     * @return a populated {@code PageVO}
     */
    public static <S, T> PageVO<T> of(Page<S> page, Class<T> clazz) {
        final List<T> converted = BeanUtil.copyToList(page.getContent(), clazz);
        return assemble(page, converted);
    }

    /**
     * Builds a {@code PageVO} that reuses the content list of the given page as-is.
     *
     * @param page the source page
     * @param <T>  element type
     * @return a populated {@code PageVO}
     */
    public static <T> PageVO<T> of(Page<T> page) {
        return assemble(page, page.getContent());
    }

    /** Copies the paging metadata of {@code source} and attaches {@code rows} as the list. */
    private static <T> PageVO<T> assemble(Page<?> source, List<T> rows) {
        final PageVO<T> result = new PageVO<>();
        result.setPage(source.getNumber() + 1);
        result.setSize(source.getSize());
        result.setTotal(source.getTotalElements());
        result.setTotalPage(source.getTotalPages());
        result.setList(rows);
        return result;
    }
}
原生sql查询通用方法
package com.example.springbootjpadruid.service;import com.alibaba.fastjson.JSON;
import com.example.springbootjpadruid.domain.common.BasePage;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.beans.SimpleTypeConverter;
import org.springframework.stereotype.Service;import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;/*** @author Jonny* @description 原生sql查询,自动将下划线转驼峰*/
@Service
public class JpaCommonService {private static final SimpleTypeConverter CONVERTER = new SimpleTypeConverter();@PersistenceContextprivate EntityManager entityManager;/*** 查询数据集合** @param sql 查询sql,参数用:name格式* @param params 查询参数map格式,key对应参数中的:name* @param clazz 实体类型,为空则直接转换为map格式* @return list*/@SuppressWarnings("unchecked")public <P extends BasePage> List<?> queryList(String sql, P params, Class<?> clazz) {String jsonParam = JSON.toJSONString(params);Map<String, Object> mapParams = JSON.parseObject(jsonParam, Map.class);// 提取 SQL 中的参数并过滤无效参数Set<String> sqlParams = extractSqlParams(sql);Map<String, Object> filteredParams = filterParams(sqlParams, mapParams);return queryList(sql, mapParams, filteredParams, clazz);}/*** 总记录数** @param sql 查询sql,参数用:name格式* @param params 查询参数map格式,key对应参数中的:name* @return list*/@SuppressWarnings("unchecked")public <P extends BasePage> long count(String sql, P params) {String jsonParam = JSON.toJSONString(params);Map<String, Object> mapParams = JSON.parseObject(jsonParam, Map.class);// 提取 SQL 中的参数并过滤无效参数Set<String> sqlParams = extractSqlParams(sql);Map<String, Object> filteredParams = filterParams(sqlParams, mapParams);return count(sql, mapParams, filteredParams, Long.class);}/*** 查询数据集合** @param sql 查询sql,参数用:name格式* @param params 查询参数map格式,key对应参数中的:name* @param filteredParams sql中使用的参数* @param clazz 实体类型,为空则直接转换为map格式* @return list*/@SuppressWarnings("all")public List<?> queryList(String sql, Map<String, Object> params, Map<String, Object> extractSqlParams, Class<?> clazz) {try {NativeQuery<?> query = entityManager.createNativeQuery(sql).unwrap(NativeQuery.class);if (Objects.nonNull(extractSqlParams)) {extractSqlParams.forEach((k, v) -> query.setParameter(k, v));}needPage(params, query);if (Objects.isNull(clazz)) {query.unwrap(NativeQuery.class).setResultTransformer(new AliasToEntityMapResultTransformer());return query.getResultList();} else {query.unwrap(NativeQuery.class).setResultTransformer(new 
AliasToBeanResultTransformer(clazz));return query.getResultList();}} catch (Exception e) {throw new RuntimeException("查询或转换数据时出错", e);}}@SuppressWarnings("all")public long count(String countSql, Map<String, Object> params, Map<String, Object> filteredParams, Class<?> clazz) {long total;try {Query countQuery = entityManager.createNativeQuery(countSql);if (Objects.nonNull(filteredParams)) {filteredParams.forEach(countQuery::setParameter);}total = Long.parseLong(countQuery.getSingleResult().toString());} catch (Exception e) {throw new RuntimeException("查询记录数出错", e);}return total;}/*** 判断是否需要分页* 当需要分页时,需要设置setNeedPage(true),并设置offset和size** @param params 查询参数* @param query 查询对象*/private void needPage(Map<String, Object> params, Query query) {if (Boolean.TRUE.equals(params.get("needPage"))) {int offset = (int) params.get("offset");int size = (int) params.get("size");query.setFirstResult(offset);query.setMaxResults(size);}}/*** 从 SQL 中提取所有命名参数** @param sql SQL 查询* @return 参数名集合*/private static Set<String> extractSqlParams(String sql) {Pattern pattern = Pattern.compile(":(\\w+)");Matcher matcher = pattern.matcher(sql);Set<String> params = new HashSet<>();while (matcher.find()) {params.add(matcher.group(1));}return params;}/*** 根据 SQL 中的参数过滤多余的 Map 参数** @param sqlParams SQL 中使用的参数* @param params 原始参数* @return 过滤后的参数*/private static Map<String, Object> filterParams(Set<String> sqlParams, Map<String, Object> params) {if (params == null || params.isEmpty()) {return Collections.emptyMap();}return params.entrySet().stream().filter(entry -> sqlParams.contains(entry.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));}private static class AliasToEntityMapResultTransformer implements ResultTransformer {@Overridepublic Object transformTuple(Object[] tuple, String[] aliases) {Map<String, Object> result = new HashMap<>();for (int i = 0; i < aliases.length; i++) {if (aliases[i] != null) {result.put(aliases[i].toLowerCase(), tuple[i]);}}return 
result;}@Overridepublic List transformList(List collection) {return collection;}}private static class AliasToBeanResultTransformer implements ResultTransformer {private final Class<?> resultClass;public AliasToBeanResultTransformer(Class<?> resultClass) {this.resultClass = resultClass;}/*** 转换结果集 自定将下划线转驼峰** @param tuple 结果集* @param aliases 字段名* @return 实体对象*/@Overridepublic Object transformTuple(Object[] tuple, String[] aliases) {try {Object result = resultClass.getDeclaredConstructor().newInstance();PropertyDescriptor[] props = Introspector.getBeanInfo(resultClass).getPropertyDescriptors();Map<String, Method> writeMethodMap = Arrays.stream(props).filter(p -> Objects.nonNull(p.getWriteMethod())).collect(Collectors.toMap(p -> p.getName().toLowerCase(), PropertyDescriptor::getWriteMethod));for (int i = 0; i < aliases.length; i++) {if (aliases[i] == null) {continue;}String fieldName = aliases[i].toLowerCase().replace("_", "");Method writeMethod = writeMethodMap.get(fieldName);if (writeMethod != null) {Object value = CONVERTER.convertIfNecessary(tuple[i], writeMethod.getParameterTypes()[0]);writeMethod.invoke(result, value);}}return result;} catch (Exception e) {throw new RuntimeException("实体映射错误: " + resultClass.getName(), e);}}@Overridepublic List transformList(List collection) {return collection;}}
}
使用方法
package com.example.springbootjpadruid.service.impl;import com.example.springbootjpadruid.domain.common.PageVO;
import com.example.springbootjpadruid.domain.entity.primary.User;
import com.example.springbootjpadruid.domain.query.UserQuery;
import com.example.springbootjpadruid.domain.vo.UserVO;
import com.example.springbootjpadruid.repository.primary.UserRepository;
import com.example.springbootjpadruid.service.JpaCommonService;
import com.example.springbootjpadruid.service.UserService;
import com.github.wenhao.jpa.PredicateBuilder;
import com.github.wenhao.jpa.Specifications;
import org.hibernate.query.NativeQuery;
import org.hibernate.transform.ResultTransformer;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.PageRequest;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;import javax.annotation.Resource;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.Query;
import java.util.List;@Service
public class UserServiceImpl implements UserService {@Resourceprivate UserRepository repository;@PersistenceContextprivate EntityManager entityManager;@Resourceprivate JpaCommonService jpaCommonService;@Overridepublic String init(List<User> users) {repository.saveAll(users);return "ok";}@Overridepublic User getUser() {PredicateBuilder<User> where = Specifications.and();where.eq("username", "zhangsan");where.eq("age", 20);return repository.findOne(where.build()).orElse(null);}/*** 分页1** @param param* @return*/@Overridepublic PageVO<UserVO> listPage(UserQuery param) {PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());Page<User> page = repository.findAll(pageable);return PageVO.of(page, UserVO.class);}/*** 分页2** @param param 查询参数* @return 分页数据*/@SuppressWarnings("all")@Overridepublic PageVO<UserVO> nativeListPage(UserQuery param) {StringBuffer sql = new StringBuffer("select id as id, username as username, age as age, address as address from sys_user where 1 = 1");buildWhere(param, sql);StringBuffer countSql = new StringBuffer("select count(1) from sys_user where 1 = 1");buildWhere(param, countSql);PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());NativeQuery query = entityManager.createNativeQuery(sql.toString()).unwrap(NativeQuery.class);query.setFirstResult(param.getOffset());query.setMaxResults(param.getSize());// 主键id无法直接映射,需要手动设置// query.setResultTransformer(Transformers.aliasToBean(UserVO.class));// 由于返回数据包含主键id,无法直接映射,只能借助投影手动映射query.setResultTransformer(new ResultTransformer() {@Overridepublic Object transformTuple(Object[] objects, String[] strings) {UserVO userVO = new UserVO();userVO.setId(Long.parseLong(objects[0].toString()));userVO.setUsername(objects[1].toString());userVO.setAge(Integer.parseInt(objects[2].toString()));userVO.setAddress(objects[3].toString());return userVO;}@Overridepublic List transformList(List list) {return list;}});buildParam(param, query);Query countQuery = 
entityManager.createNativeQuery(countSql.toString());buildParam(param, countQuery);long total = Long.parseLong(countQuery.getSingleResult().toString());PageImpl<UserVO> pageImpl = new PageImpl<>(query.getResultList(), pageable, total);return PageVO.of(pageImpl);}/*** 分页3* 原生sql自定转驼峰,带下划线不管大小自动转标准驼峰,全大写会转小写然后转标准驼峰* 解决分页方法2的原生sql如果没用as别名,无法自动转驼峰的问题** @param param* @return*/@SuppressWarnings("all")@Overridepublic PageVO<UserVO> listPageByJpaUtil(UserQuery param) {// StringBuffer sql = new StringBuffer("select id, username, age, address, user_role as USER_ROLE from sys_user where 1 = 1");StringBuffer sql = new StringBuffer("select * from sys_user where 1 = 1");buildWhere(param, sql);StringBuffer countSql = new StringBuffer("select count(1) from sys_user where 1 = 1");buildWhere(param, countSql);PageRequest pageable = PageRequest.of(param.getPage(), param.getSize());param.setNeedPage(Boolean.TRUE);List<UserVO> userList = (List<UserVO>) jpaCommonService.queryList(sql.toString(), param, UserVO.class);long total = jpaCommonService.count(countSql.toString(), param);PageImpl<UserVO> pageImpl = new PageImpl<>(userList, pageable, total);return PageVO.of(pageImpl);}private void buildParam(UserQuery param, Query query) {if (StringUtils.hasText(param.getUsername())) {query.setParameter("username", param.getUsername());}}private void buildWhere(UserQuery param, StringBuffer sql) {if (StringUtils.hasText(param.getUsername())) {sql.append(" and username = :username");}}}