// ExtendedJpaRepositoryImpl.java

package org.gringlobal.spring.persistence;

import java.io.Serializable;
import java.util.List;

import javax.persistence.EntityManager;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.mapping.PropertyPath;
import org.springframework.data.querydsl.EntityPathResolver;
import org.springframework.data.querydsl.QSort;
import org.springframework.data.querydsl.SimpleEntityPathResolver;

import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.PathBuilder;
import com.querydsl.core.types.EntityPath;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.OrderSpecifier;
import com.querydsl.core.types.OrderSpecifier.NullHandling;
import com.querydsl.core.types.Path;
import com.querydsl.jpa.impl.JPAQuery;
import com.querydsl.jpa.JPQLQuery;

import org.springframework.util.Assert;
import org.springframework.data.support.PageableExecutionUtils;

/**
 * Repository base implementation that adds paging of Querydsl {@link JPAQuery} instances on top of
 * {@link SimpleJpaRepository}.
 *
 * <p>A plain {@link Sort} is converted into Querydsl {@link OrderSpecifier}s rooted at this
 * repository's domain type. A {@link QSort} is applied with its order specifiers unchanged.
 *
 * @param <T>  the domain type
 * @param <ID> the identifier type
 */
public class ExtendedJpaRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID>
    implements ExtendedJpaRepository<T, ID> {

  /**
   * Querydsl path builder rooted at the domain type. It is used to resolve {@link Sort} property
   * names into sortable expressions.
   */
  private final PathBuilder<T> builder;

  /**
   * Creates the repository for the given entity.
   *
   * @param entityInformation metadata about the managed entity
   * @param entityManager     the JPA entity manager, handed to {@link SimpleJpaRepository}
   */
  public ExtendedJpaRepositoryImpl(JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
    super(entityInformation, entityManager);
    // The default resolver is used directly. An injected field could not influence this value,
    // because field injection only happens after the constructor has already run.
    EntityPath<T> path = SimpleEntityPathResolver.INSTANCE.createPath(super.getDomainClass());
    this.builder = new PathBuilder<T>(path.getType(), path.getMetadata());
  }

  /**
   * Executes the given query and returns one page of its results.
   *
   * <p>The total row count is computed first. The query is then mutated in place with the page's
   * offset, limit and sort before its content is fetched.
   *
   * @param query the query to execute; must not be {@literal null}. It is modified by this call
   *              when {@code page} is paged.
   * @param page  the requested page; must not be {@literal null}
   * @return the requested page of results, carrying the total count
   */
  @Override
  public Page<T> findAll(JPAQuery<T> query, Pageable page) {
    Assert.notNull(query, "Query must not be null!");
    Assert.notNull(page, "Pageable must not be null!");

    // The count is taken before applyPagination adds offset/limit/order to the query.
    long total = query.fetchCount();
    applyPagination(page, query);

    return PageableExecutionUtils.getPage(query.fetch(), page, () -> total);
  }

  /**
   * Applies the offset, limit and sort of the given {@link Pageable} to the query.
   *
   * <p>An unpaged {@code Pageable} leaves the query untouched, and its sort is not applied.
   *
   * @param pageable must not be {@literal null}.
   * @param query    must not be {@literal null}.
   * @return the same query instance
   */
  private JPAQuery<T> applyPagination(Pageable pageable, JPAQuery<T> query) {

    Assert.notNull(pageable, "Pageable must not be null!");
    Assert.notNull(query, "Query must not be null!");

    if (pageable.isUnpaged()) {
      return query;
    }

    query.offset(pageable.getOffset());
    query.limit(pageable.getPageSize());

    return applySorting(pageable.getSort(), query);
  }

  /**
   * Appends the given sort to the query's ORDER BY clause.
   *
   * <p>A {@link QSort} contributes its order specifiers directly. Any other {@link Sort} is
   * converted property by property.
   *
   * @param sort  must not be {@literal null}.
   * @param query must not be {@literal null}.
   * @return the same query instance
   */
  private JPAQuery<T> applySorting(Sort sort, JPAQuery<T> query) {
    Assert.notNull(sort, "Sort must not be null!");
    Assert.notNull(query, "Query must not be null!");

    if (sort.isUnsorted()) {
      return query;
    }

    if (sort instanceof QSort) {
      return addOrderByFrom((QSort) sort, query);
    }

    return addOrderByFrom(sort, query);
  }

  /**
   * Applies the {@link OrderSpecifier}s of the given {@link QSort} to the given {@link JPQLQuery}
   * unchanged.
   *
   * @param qsort must not be {@literal null}.
   * @param query must not be {@literal null}.
   * @return the same query instance
   */
  private JPAQuery<T> addOrderByFrom(QSort qsort, JPAQuery<T> query) {

    List<OrderSpecifier<?>> orderSpecifiers = qsort.getOrderSpecifiers();

    return query.orderBy(orderSpecifiers.toArray(new OrderSpecifier<?>[0]));
  }

  /**
   * Converts the {@link Order} items of the given {@link Sort} into {@link OrderSpecifier}s and
   * attaches them to the given {@link JPQLQuery}, in iteration order.
   *
   * @param sort  must not be {@literal null}.
   * @param query must not be {@literal null}.
   * @return the same query instance
   */
  private JPAQuery<T> addOrderByFrom(Sort sort, JPAQuery<T> query) {

    Assert.notNull(sort, "Sort must not be null!");
    Assert.notNull(query, "Query must not be null!");

    for (Order order : sort) {
      query.orderBy(toOrderSpecifier(order));
    }

    return query;
  }

  /**
   * Transforms a plain Spring Data {@link Order} into a Querydsl {@link OrderSpecifier}.
   *
   * <p>The direction, property path (lower-cased when ignore-case applies) and null handling are
   * carried over.
   *
   * @param order must not be {@literal null}.
   * @return the equivalent Querydsl order specifier
   */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private OrderSpecifier<?> toOrderSpecifier(Order order) {

    // Raw constructor: the target expression's type is only known at runtime.
    return new OrderSpecifier(
        order.isAscending() ? com.querydsl.core.types.Order.ASC : com.querydsl.core.types.Order.DESC,
        buildOrderPropertyPathFrom(order), toQueryDslNullHandling(order.getNullHandling()));
  }

  /**
   * Converts the given {@link org.springframework.data.domain.Sort.NullHandling} to the matching
   * Querydsl {@link NullHandling}.
   *
   * <p>{@code NATIVE} and any unrecognised value map to {@link NullHandling#Default}.
   *
   * @param nullHandling must not be {@literal null}.
   * @return the Querydsl null handling
   */
  private NullHandling toQueryDslNullHandling(org.springframework.data.domain.Sort.NullHandling nullHandling) {

    Assert.notNull(nullHandling, "NullHandling must not be null!");

    switch (nullHandling) {

      case NULLS_FIRST:
        return NullHandling.NullsFirst;

      case NULLS_LAST:
        return NullHandling.NullsLast;

      case NATIVE:
      default:
        return NullHandling.Default;
    }
  }

  /**
   * Creates an {@link Expression} for the (possibly nested, dot-separated) property of the given
   * {@link Order}. The expression is rooted at the domain type.
   *
   * <p>If the order ignores case and the final path segment is a {@link String}, that segment is
   * wrapped in {@code lower()}.
   *
   * @param order must not be {@literal null}.
   * @return the sort expression
   */
  private Expression<?> buildOrderPropertyPathFrom(Order order) {

    Assert.notNull(order, "Order must not be null!");

    PropertyPath path = PropertyPath.from(order.getProperty(), builder.getType());
    Expression<?> sortPropertyExpression = builder;

    // Walk the property path one segment at a time, nesting each segment under the previous one.
    while (path != null) {

      sortPropertyExpression = !path.hasNext() && order.isIgnoreCase() && String.class.equals(path.getType()) //
          ? Expressions.stringPath((Path<?>) sortPropertyExpression, path.getSegment()).lower() //
          : Expressions.path(path.getType(), (Path<?>) sortPropertyExpression, path.getSegment());

      path = path.next();
    }

    return sortPropertyExpression;
  }

}