ElasticQueryBuilder.java

/*
 * Copyright 2020 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.gringlobal.custom.elasticsearch;

import static com.google.common.collect.Lists.*;
import static org.elasticsearch.index.query.QueryBuilders.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

import com.google.common.collect.ImmutableList;
import com.querydsl.core.QueryMetadata;
import com.querydsl.core.types.Constant;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.FactoryExpression;
import com.querydsl.core.types.Operation;
import com.querydsl.core.types.OperationImpl;
import com.querydsl.core.types.Operator;
import com.querydsl.core.types.Ops;
import com.querydsl.core.types.ParamExpression;
import com.querydsl.core.types.Path;
import com.querydsl.core.types.PathMetadata;
import com.querydsl.core.types.PathType;
import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.SubQueryExpression;
import com.querydsl.core.types.TemplateExpression;
import com.querydsl.core.types.Visitor;
import com.querydsl.core.types.dsl.NumberPath;

import lombok.extern.slf4j.Slf4j;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchPhrasePrefixQueryBuilder;
import org.elasticsearch.index.query.MatchPhraseQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.springframework.data.elasticsearch.annotations.Field;
import org.springframework.data.elasticsearch.annotations.FieldType;

/**
 * Converter from a Querydsl predicate to Elasticsearch query.
 *
 * @author Matija Obreza
 * @author Maxym Borodenko
 */
@Slf4j
public class ElasticQueryBuilder implements Visitor<Void, Void> {

	private final List<QueryBuilder> mustClauses = new ArrayList<>();
	private final List<QueryBuilder> mustNotClauses = new ArrayList<>();

	private final Map<String, RangeQueryBuilder> ranges = new HashMap<>();

	private final ElasticQueryBuilder self = this;

	/*
	 * The name of the root path is (unfortunately) not available
	 * so we ignore the root name for properties: <code>accession.accessionNumber</code>
	 * becomes <code>accession</code>.
	 * 
	 * In exists sub-queries we want to include the full path. This flags controls
	 * how the path name is converted to ES term name.
	 */
	private boolean useFullPathNames = false;


	public QueryBuilder getQuery() {
		BoolQueryBuilder root = QueryBuilders.boolQuery();
		mustClauses.forEach(must -> root.filter(must));
		mustNotClauses.forEach(mustNot -> root.mustNot(mustNot));
		// shouldClauses.forEach(should -> root.should(should));
		// if (shouldClauses.size() > 0) root.minimumNumberShouldMatch(1);
		return root;
	}

	private String customizedPath(String path) {
		// Check useFullPathNames!
		// Just remove the entity name from the path -- hopefully that's fine
		if (!useFullPathNames && path.contains(".")) {
			int firstDot = path.indexOf('.');
			path = path.substring(firstDot + 1);
		}
		return path;
	}

	@Override
	public Void visit(Constant<?> c, Void context) {
		log.debug("+Constant: {}", c.getConstant());
		return null;
	}

	@Override
	public Void visit(FactoryExpression<?> expr, Void context) {
		log.debug("+FactoryExpression: {}", expr.getArgs());
		return null;
	}

	@Override
	public Void visit(Operation<?> expr, Void context) {
		log.debug("+Operation: {} {} {}", expr.getType(), expr.getOperator(), expr.getArgs());
		visitOperation(expr.getType(), expr.getOperator(), expr.getArgs(), context);
		return null;
	}

	private void visitOperation(Class<?> type, Operator operator, List<Expression<?>> args, Void context) {
		if (operator == Ops.AND) {
			handleAnd(args);
		} else if (operator == Ops.OR) {
			for (Expression<?> expr : args) {
				printExpression(".. " + operator, expr);
				expr.accept(self, null);
			}
		} else if (operator == Ops.EQ || operator == Ops.IN) {
			if (Path.class.isAssignableFrom(args.get(0).getClass())) {
				log.debug("EQUALS: {}", args);
				for (Expression<?> expr : args) {
					printExpression("EQUALS.. ", expr);
				}
				Path<?> a0 = (Path<?>) args.get(0);
				Expression<?> a1 = args.get(1);
				handleEquals(a0, a1);
			} else {
				Path<?> path = (Path<?>)((OperationImpl<?>) args.get(0)).getArg(0);
				PathMetadata pmd = path.getMetadata();
				mustNotClauses.add(existsQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName())));
			}
		} else if (operator == Ops.NE) {
			log.debug("NOT EQUALS: {}", args);
			for (Expression<?> expr : args) {
				printExpression("NOT EQUALS.. ", expr);
			}
			Path<?> a0 = (Path<?>) args.get(0);
			Expression<?> a1 = args.get(1);
			PathMetadata pmd = a0.getMetadata();
			mustNotClauses.add(termsQuery(customizedPath(getParentPath(pmd.getParent()) + "." + pmd.getName()), toValues(a1)));
		} else if (operator == Ops.LOE || operator == Ops.GOE || operator == Ops.BETWEEN || operator == Ops.LT || operator == Ops.GT) {
			if (Path.class.isAssignableFrom(args.get(0).getClass())) {
				log.debug("Range: {}", args);
				for (Expression<?> expr : args) {
					printExpression("LOE.. ", expr);
				}
				Path<?> a0 = (Path<?>) args.get(0);
				handleRange(operator, a0, args.get(1), args.size() > 2 ? args.get(2) : null);
			} else {
				Path<?> path = (Path<?>)((OperationImpl<?>) args.get(0)).getArg(0);
				PathMetadata pmd = path.getMetadata();
				mustClauses.add(existsQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName())));
			}
		} else if (operator == Ops.STRING_CONTAINS || operator == Ops.STARTS_WITH) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			Path<?> a0 = (Path<?>) args.get(0);
			Expression<?> a1 = args.get(1);

			handleLike(operator, a0, a1);
		} else if (operator == Ops.NOT) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			Expression<?> notExp = args.get(0);
			handleNot(notExp);
		} else if (operator == Ops.IS_NOT_NULL) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			Path<?> path = (Path<?>) args.get(0);
			PathMetadata pmd = path.getMetadata();
			mustClauses.add(existsQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName())));
		} else if (operator == Ops.IS_NULL) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			Path<?> path = (Path<?>) args.get(0);
			PathMetadata pmd = path.getMetadata();
			mustNotClauses.add(existsQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName())));
		} else if (operator == Ops.COL_IS_EMPTY) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			Path<?> path = (Path<?>) args.get(0);
			PathMetadata pmd = path.getMetadata();

			// If collection is nested in ES, exists doesn't work!
			var annoEl = path.getAnnotatedElement();
			var fieldAnnotation = annoEl.getAnnotation(Field.class);
			if (fieldAnnotation != null && fieldAnnotation.type() == FieldType.Nested) {
				log.debug("Skipping exists query on a ES Nested field");
			} else {
				mustNotClauses.add(existsQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName())));
			}
		} else if (operator == Ops.EXISTS) {
			log.debug("{}: {}", operator, args);
			for (Expression<?> expr : args) {
				printExpression(operator + ".. ", expr);
			}
			SubQueryExpression<?> subQuery = (SubQueryExpression<?>) args.get(0);
			var metadata = subQuery.getMetadata();
			var projection = metadata.getProjection();
			var where = metadata.getWhere();
			printExpression(operator + ".. ", where);
			handleNestedQuery(projection.toString(), where); // FIXME Alternative to .toString()?

		} else {
			log.error("Op {}: {}", operator, args);
		}
		// Expression<?> a0 = args.get(0);
		// Expression<?> a1 = args.get(1);
		// printExpression("a1: " + type.getName() + " " + operator, a1);
	}

	private void handleNestedQuery(String path, Predicate where) {
		log.debug("Nested query for {}", path);
		ElasticQueryBuilder andBuilder = new ElasticQueryBuilder();
		andBuilder.useFullPathNames = true;
		where.accept(andBuilder, null);

		mustClauses.add(nestedQuery(path, andBuilder.getQuery(), ScoreMode.Total));
	}

	private int size() {
		return this.mustClauses.size() + this.mustNotClauses.size();
	}

	private void handleAnd(List<Expression<?>> args) {
		handleAnd(args, useFullPathNames);
	}

	private void handleAnd(List<Expression<?>> args, boolean useFullPathNames) {
		log.debug("AND expr: {}", args);
		ElasticQueryBuilder andBuilder = new ElasticQueryBuilder();
		andBuilder.useFullPathNames = useFullPathNames;
		for (Expression<?> a : args) {
			a.accept(andBuilder, null);
		}
		if (andBuilder.size() == 1 && andBuilder.mustClauses.size() > 0) {
			mustClauses.addAll(andBuilder.mustClauses);
		} else if (andBuilder.size() == 1 && andBuilder.mustNotClauses.size() > 0) {
			mustNotClauses.addAll(andBuilder.mustNotClauses);
		}  else {
			mustClauses.add(andBuilder.getQuery());
		}
	}

	private void handleNot(Expression<?> notExp) {
		log.debug("NOT expr: {}", notExp);
		ElasticQueryBuilder notBuilder = new ElasticQueryBuilder();
		notBuilder.useFullPathNames = useFullPathNames;
		notExp.accept(notBuilder, null);

		notBuilder.mustClauses.forEach(mustNot -> mustNotClauses.add(mustNot));
		notBuilder.mustNotClauses.forEach(must -> mustClauses.add(must));
	}

	private void handleLike(Operator operator, Path<?> path, Expression<?> val) {
		PathMetadata pmd = path.getMetadata();
		// SimpleQueryStringBuilder qsq = simpleQueryStringQuery( +":" + toValue(val));
		if (operator == Ops.STARTS_WITH) {
			MatchPhrasePrefixQueryBuilder matchPrefixQuery = matchPhrasePrefixQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName()), toValue(val));
			mustClauses.add(matchPrefixQuery);
		} else if (operator == Ops.STRING_CONTAINS) {
			MatchPhraseQueryBuilder matchPrefixQuery = matchPhraseQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName()), toValue(val));
			mustClauses.add(matchPrefixQuery);
		} else {
			throw new RuntimeException("Unsupported ES handleLike operator: " + operator);
		}
	}

	private void handleRange(Operator operator, Path<?> path, Expression<?> val1, Expression<?> val2) {
		PathMetadata pmd = path.getMetadata();
		RangeQueryBuilder rq;

		if (ranges.get(path.toString()) != null) {
			rq = ranges.get(path.toString());
		} else {
			rq = rangeQuery(customizedPath(pmd.getParent().toString() + "." + pmd.getName()));
			ranges.put(path.toString(), rq);
		}

		if (operator == Ops.LOE) {
			rq.lte(toValue(val1));
		} else if (operator == Ops.LT) {
			rq.lt(toValue(val1));
		} else if (operator == Ops.GOE) {
			rq.gte(toValue(val1));
		} else if (operator == Ops.GT) {
			rq.gt(toValue(val1));
		} else if (operator == Ops.BETWEEN) {
			rq.gte(toValue(val1));
			rq.lte(toValue(val2));
		} else if (operator == Ops.LOE) {
			rq.lte(toValue(val1));
		}
		mustClauses.add(rq);
	}

	private void handleEquals(Path<?> path, Expression<?> value) {
		PathMetadata pmd = path.getMetadata();
		if (pmd.getPathType() == PathType.COLLECTION_ANY) {
			log.debug("Path ANY for {}={}", pmd.getParent(), value);
			mustClauses.add(termsQuery(customizedPath(pmd.getParent().toString()), toValues(value)));
		} else if (value instanceof Path) {
			// This happens mostly in sub-queries where value is the reference back to the main entity
			log.debug("Skipping path variable {}", value);
		} else {
			log.debug("Path {} for {}={}", pmd.getPathType(), pmd.getParent(), value);
			mustClauses.add(termsQuery(customizedPath(getParentPath(pmd.getParent()) + "." + pmd.getName()), toValues(value)));
		}
	}

	private String getParentPath(Path<?> path) {
		String pathValue = path.toString();

		if (pathValue.startsWith("any")) {
			return getParentPath(path.getMetadata().getParent());
		}

		return pathValue;
	}

	private Object toValue(Expression<?> value) {
		if (value instanceof Constant<?>) {
			Constant<?> cons = (Constant<?>) value;
			Object obj = cons.getConstant();
			return convertValue(obj);
		}

		throw new RuntimeException("Unhandled value " + value);
	}

	private Object convertValue(Object obj) {
		if (obj == null) {
			return null;
		}
		Class<? extends Object> objClass = obj.getClass();
		log.debug("toValue of {}: c={}", obj, objClass);
		if (objClass.isEnum()) {
			return obj.toString();
		}
		if (UUID.class.isAssignableFrom(objClass)) {
			return obj.toString();
		}
		return obj;
	}

	private Collection<?> toValues(Expression<?> value) {
		if (value instanceof Constant<?>) {
			Constant<?> cons = (Constant<?>) value;
			Object obj = cons.getConstant();
			log.debug("toValues of {}: c={}", obj, obj.getClass());

			if (obj instanceof Collection<?>) {
				Collection<?> c = (Collection<?>) obj;
				return c.stream().map((forString) -> convertValue(forString)).collect(Collectors.toList());
			} else {
				return newArrayList(convertValue(obj));
			}
		}

		throw new RuntimeException("Unhandled value " + value + " of type " + value.getClass());
	}

	private void printExpression(String prefix, Expression<?> expr) {
		if (expr instanceof NumberPath<?>) {
			NumberPath<?> path = (NumberPath<?>) expr;
			log.debug("{}: NumberPath {} {}", prefix, path.getRoot(), path.getType());

		} else if (expr instanceof Path<?>) {
			Path<?> path = (Path<?>) expr;
			PathMetadata pmd = path.getMetadata();
			if (pmd.getPathType() == PathType.COLLECTION_ANY) {
				log.debug("{}: {} {} parent={}", prefix, pmd.getPathType(), pmd.getElement(), pmd.getParent());
			} else {
				log.debug("{}: {} {}/{} parent={}", prefix, pmd.getPathType(), pmd.getName(), pmd.getElement(), pmd.getParent());
			}

		} else if (expr instanceof Constant<?>) {
			Constant<?> cons = (Constant<?>) expr;
			log.debug("{}: Constant {} {}", prefix, cons.getConstant(), cons.getType());

		} else if (expr instanceof Predicate) {
			Predicate pred = (Predicate) expr;
			log.debug("{}: should visit Predicate {}", prefix, pred);

		} else if (expr instanceof SubQueryExpression) {
			SubQueryExpression<?> query = (SubQueryExpression<?>) expr;
			log.debug("{}: should visit SubQueryExpression {}", prefix, query);

		} else {
			log.debug("{}: {} {}", prefix, expr.getClass(), expr.getType());
		}
	}

	@Override
	public Void visit(ParamExpression<?> param, Void context) {
		log.debug("+ParamExpression: {} {} {}", param.getType(), param.isAnon(), param.getName());
		return null;
	}

	@Override
	public Void visit(Path<?> path, Void context) {
		final PathType pathType = path.getMetadata().getPathType();
		final Object element = path.getMetadata().getElement();
		List<Object> args;
		if (path.getMetadata().getParent() != null) {
			args = ImmutableList.of(path.getMetadata().getParent(), element);
		} else {
			args = ImmutableList.of(element);
		}
		log.debug("+Path: {} {} {}", pathType, pathType.name(), args);
		return null;
	}

	@Override
	public Void visit(SubQueryExpression<?> query, Void context) {
		QueryMetadata qm = query.getMetadata();
		log.debug("+SubQueryExpression: {}", qm);

		return null;
	}

	@Override
	public Void visit(TemplateExpression<?> expr, Void context) {
		log.debug("+TemplateExpr: {} {}", expr.getTemplate(), expr.getArgs());
		return null;
	}

}