// TraitDataServiceImpl.java

/*
 * Copyright 2024 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.gringlobal.api.v2.facade.impl;

import java.util.List;
import java.util.LinkedList;
import java.util.Map.Entry;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.stream.Collectors;

import java.util.function.Function;

import org.apache.commons.collections4.CollectionUtils;
import org.gringlobal.api.model.InventoryInfo;
import org.gringlobal.api.v2.facade.TraitDataService;
import org.gringlobal.api.v2.mapper.MapstructMapper;
import org.gringlobal.api.v2.model.TraitDataFilter;
import org.gringlobal.api.v2.model.Traits;
import org.gringlobal.api.v2.model.TraitsPage;
import org.gringlobal.model.QAccession;
import org.gringlobal.model.QCropTrait;
import org.gringlobal.model.QCropTraitCode;
import org.gringlobal.model.QCropTraitObservation;
import org.gringlobal.model.QInventory;
import org.gringlobal.service.CropTraitTranslationService;
import org.gringlobal.service.CropTraitTranslationService.TranslatedCropTrait;
import org.gringlobal.service.filter.CropTraitFilter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cache.CacheManager;
import org.springframework.cache.interceptor.SimpleKey;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;

import com.querydsl.core.types.Predicate;
import com.querydsl.core.types.ExpressionUtils;
import com.querydsl.jpa.impl.JPAQueryFactory;

import lombok.extern.slf4j.Slf4j;

/**
 * Provides trait observation data in wide format: one {@link Traits} record per inventory, holding the
 * inventory info and its observations grouped by crop trait id.
 *
 * <p>
 * The ordered list of candidate inventory ids is looked up through the {@code inventoryTraitsWide} cache.
 * Cached lists are shared between calls and are therefore never modified here; narrowing always
 * produces a new list.
 * </p>
 *
 * @author Matija Obreza
 */
@Service
@Slf4j
public class TraitDataServiceImpl implements TraitDataService {

	@Autowired
	private JPAQueryFactory jpaQueryFactory;

	@Autowired
	private MapstructMapper mapper;

	@Autowired
	private CropTraitTranslationService cropTraitTranslationService;

	@Autowired
	private CacheManager cacheManager;

	/**
	 * Load one page of trait data.
	 *
	 * @param filter optional filter (crop traits, methods, inventory and observation value filters); may be {@code null}
	 * @param page the requested page; its offset and page size select a slice of the matching inventories
	 * @return the page of {@link Traits}, with the total number of matching inventories. Translated trait
	 *         definitions are attached only when the first page (offset 0) is requested.
	 */
	@Override
	public TraitsPage getTraitData(TraitDataFilter filter, Pageable page) {

		var cropTraitIds = resolveCropTraitIds(filter);
		log.info("Working with {} traits", cropTraitIds.size());

		var allInventories = getCachedInventoriesForFilter(filter, cropTraitIds);
		log.info("Found {} inventories matching intial filters", allInventories.size());

		// Keep only inventories that have the requested trait observations, if any were requested
		var selectedInventories = filterByObservations(filter, allInventories);
		log.info("Will consider {} inventories with {} different traits for pagination", selectedInventories.size(), cropTraitIds.size());

		// Paginate. Bounds are clamped as long values before narrowing so that a very large offset
		// cannot overflow into a negative index.
		int total = selectedInventories.size();
		int start = (int) Math.min(page.getOffset(), total);
		int end = (int) Math.min(page.getOffset() + page.getPageSize(), total);

		var pageInventories = selectedInventories.subList(start, end);
		log.info("Using {} inventories for start: {} end: {} page: {}", pageInventories.size(), start, end, page);

		// An empty page needs no database access (and would otherwise produce empty IN lists)
		List<Traits> list = pageInventories.isEmpty() ? List.of() : loadTraits(filter, cropTraitIds, pageInventories);
		log.info("Have final list of {} items", list.size());

		var res = new TraitsPage(list, page, total);
		if (page.getOffset() == 0) {
			// Add selected trait definitions
			var ctf = new CropTraitFilter();
			ctf.id = new HashSet<>(cropTraitIds);
			log.info("Loading {} translated CropTrait definitions", cropTraitIds.size());
			res.traits = cropTraitTranslationService.list(ctf, Pageable.unpaged()).stream().collect(Collectors.toMap(TranslatedCropTrait::getId, Function.identity()));
		}
		log.info("Done.");
		return res;
	}

	/**
	 * Tests whether the filter restricts observations by method.
	 */
	private static boolean hasMethodFilter(TraitDataFilter filter) {
		return filter != null && CollectionUtils.isNotEmpty(filter.getMethodIds());
	}

	/**
	 * Determine the crop trait ids to work with: the explicit selection from the filter, otherwise the
	 * traits observed with the selected methods, otherwise all crop traits.
	 */
	private List<Long> resolveCropTraitIds(TraitDataFilter filter) {
		List<Long> requested = filter == null ? null : filter.getCropTraitIds();
		if (CollectionUtils.isNotEmpty(requested)) {
			// Work with the selection
			return requested;
		}

		if (hasMethodFilter(filter)) {
			return jpaQueryFactory.from(QCropTraitObservation.cropTraitObservation)
				.select(QCropTraitObservation.cropTraitObservation.cropTrait().id).distinct()
				.orderBy(QCropTraitObservation.cropTraitObservation.cropTrait().id.asc())
				.where(QCropTraitObservation.cropTraitObservation.method().id.in(filter.getMethodIds()))
				.fetch();
		}

		// All traits
		return jpaQueryFactory.from(QCropTrait.cropTrait)
			.select(QCropTrait.cropTrait.id)
			.orderBy(QCropTrait.cropTrait.id.asc())
			.fetch();
	}

	/**
	 * Narrow the ordered inventory list to inventories that satisfy every observation filter.
	 *
	 * <p>
	 * The input list is never modified (it may be a shared cached instance). The relative order of the
	 * input is preserved, so the result needs no re-sorting.
	 * </p>
	 *
	 * @return {@code allInventories} itself when nothing was filtered, otherwise a new list
	 */
	private List<Long> filterByObservations(TraitDataFilter filter, List<Long> allInventories) {
		if (filter == null || filter.getObservations() == null) {
			return allInventories;
		}

		var selected = allInventories;
		for (var obs : filter.getObservations().entrySet()) {
			if (selected.isEmpty()) {
				// Nothing left to narrow down, further queries cannot change the outcome
				break;
			}
			log.info("Filtering for obs: {} {}", obs.getKey(), obs.getValue());
			var valFilt = obs.getValue();

			var obsQ = jpaQueryFactory.from(QCropTraitObservation.cropTraitObservation)
				.select(QCropTraitObservation.cropTraitObservation.inventory().id).distinct()
				.orderBy(QCropTraitObservation.cropTraitObservation.inventory().id.asc());

			if (hasMethodFilter(filter)) { // Apply filtering by method
				log.info("Applying filter by methodId: {}", filter.getMethodIds());
				obsQ.where(QCropTraitObservation.cropTraitObservation.method().id.in(filter.getMethodIds()));
			}

			// All value predicates of one trait must hold on the same observation
			var obsF = new ArrayList<Predicate>(3);
			if (valFilt.numericValue != null) {
				obsF.add(QCropTraitObservation.cropTraitObservation.cropTrait().id.eq(obs.getKey()).and(valFilt.numericValue.buildQuery(QCropTraitObservation.cropTraitObservation.numericValue)));
			}
			if (valFilt.stringValue != null) {
				obsF.add(QCropTraitObservation.cropTraitObservation.cropTrait().id.eq(obs.getKey()).and(valFilt.stringValue.buildQuery(QCropTraitObservation.cropTraitObservation.stringValue)));
			}
			if (CollectionUtils.isNotEmpty(valFilt.cropTraitCode)) {
				var aliasCTC = new QCropTraitCode("ctc");
				obsQ.leftJoin(QCropTraitObservation.cropTraitObservation.cropTraitCode(), aliasCTC); // Must be left join
				obsF.add(QCropTraitObservation.cropTraitObservation.cropTrait().id.eq(obs.getKey()).and(aliasCTC.code.in(valFilt.cropTraitCode)));
			}

			if (obsF.isEmpty()) {
				log.info("Filter for cropTrait: {} was empty!", obs.getKey());
				continue;
			}

			// Collect matching ids into a set for constant-time lookups. The stream is closed explicitly
			// because the persistence provider may keep a cursor open until then.
			HashSet<Long> matchingIds;
			try (var ids = obsQ.where(ExpressionUtils.allOf(obsF)).stream()) {
				matchingIds = ids.collect(Collectors.toCollection(HashSet::new));
			}

			// Filtering the ordered list keeps the accession/inventory ordering intact
			selected = selected.stream().filter(matchingIds::contains).collect(Collectors.toList());
			log.info("Narrowed down to {} inventories with cropTrait: {}", selected.size(), obs.getKey());
		}
		return selected;
	}

	/**
	 * Load observations and inventory info for one (non-empty) page of inventories.
	 *
	 * @param filter optional filter; only its method ids are applied here
	 * @param cropTraitIds crop traits to load observations for
	 * @param pageInventories inventory ids of the page, in result order
	 * @return one {@link Traits} per inventory that has observations, ordered as {@code pageInventories}
	 */
	private List<Traits> loadTraits(TraitDataFilter filter, List<Long> cropTraitIds, List<Long> pageInventories) {
		// Load CropTraitObservations for the selected page of inventories -- no limits!
		// WARNING may need breaking up into smaller chunks so we don't hit MSSQL limits
		var ctoQ = jpaQueryFactory.from(QCropTraitObservation.cropTraitObservation)
			.select(QCropTraitObservation.cropTraitObservation)
			.where(QCropTraitObservation.cropTraitObservation.cropTrait().id.in(cropTraitIds) // WARNING Limits of 1000!
				.and(QCropTraitObservation.cropTraitObservation.inventory().id.in(pageInventories)) // WARNING Limits of 1000!
			);

		if (hasMethodFilter(filter)) { // Apply filtering by method
			ctoQ.where(QCropTraitObservation.cropTraitObservation.method().id.in(filter.getMethodIds()));
		}

		var observations = ctoQ.fetch();
		log.info("Fetched {} observations", observations.size());

		// Group observations by inventory, then by crop trait
		var resultByInvId = new LinkedHashMap<Long, Traits>(pageInventories.size());
		for (var observation : observations) {
			var invId = observation.getInventory().getId();
			log.trace("Processing {}", invId);
			var inventoryTraits = resultByInvId.computeIfAbsent(invId, (id) -> {
				var t = new Traits();
				t.setCto(new LinkedHashMap<>());
				return t;
			});
			inventoryTraits.getCto()
				.computeIfAbsent(observation.getCropTrait().getId(), (k) -> new LinkedList<>())
				.add(mapper.mapValue(observation));
		}

		// Load inventory info for selected page
		log.info("Loading data of {} inventories", pageInventories.size());
		var inventoryData = jpaQueryFactory.from(QInventory.inventory)
			.select(QInventory.inventory)
			.where(QInventory.inventory.id.in(pageInventories))
			.fetch() // Page-sized result, no need to stream
			.stream().map(inventory -> {
				return mapper.mapInfo(inventory);
			}).collect(Collectors.toMap(InventoryInfo::getId, Function.identity()));

		log.info("Got InventoryInfo for {} records", inventoryData.size());

		// Return results as sorted in pageInventories: walk the page once and pick up the grouped data
		var list = new ArrayList<Traits>(resultByInvId.size());
		for (var invId : pageInventories) {
			var traits = resultByInvId.get(invId);
			if (traits != null) {
				// Assign inventory info by key
				traits.setI(inventoryData.get(invId));
				list.add(traits);
			}
		}
		return list;
	}

	/**
	 * Cached version of {@link #getInventoriesForFilter(TraitDataFilter, List)}. Uses the
	 * {@code inventoryTraitsWide} cache when it is configured, otherwise queries directly.
	 *
	 * <p>
	 * The returned list may be shared with other callers and must not be modified.
	 * </p>
	 */
	private List<Long> getCachedInventoriesForFilter(TraitDataFilter filter, List<Long> cropTraitIds) {
		var cache = cacheManager.getCache("inventoryTraitsWide");
		var cacheKey = new SimpleKey(filter, cropTraitIds);
		if (cache != null) {
			log.debug("Using cache to get inventory IDs for filter and traits");
			return cache.get(cacheKey, () -> getInventoriesForFilter(filter, cropTraitIds));
		} else {
			log.debug("Caching is not enabled");
			return getInventoriesForFilter(filter, cropTraitIds);
		}
	}

	/**
	 * Find all Inventory (ordered by accessionNumber, inventoryNumber) that has at least one CropTraitObservation for selected CropTraits
	 *
	 * @param filter optional filter; its method ids and inventory filter are applied when present
	 * @param cropTraitIds crop traits that the inventory must have observations for
	 * @return ordered list of matching inventory ids; empty when no crop traits are selected
	 */
	private List<Long> getInventoriesForFilter(TraitDataFilter filter, List<Long> cropTraitIds) {
		log.warn("Finding all inventories for cropTraitIds: {} and filter {}", cropTraitIds, filter);
		if (CollectionUtils.isEmpty(cropTraitIds)) {
			// No traits means no observations can match; avoid querying with an empty IN list
			return new ArrayList<>();
		}

		var aliasInventory = new QInventory("i");
		var aliasAccession = new QAccession("a");
		// The ordering columns are part of the projection because the query is DISTINCT
		var invQ = jpaQueryFactory.from(QCropTraitObservation.cropTraitObservation)
			.innerJoin(QCropTraitObservation.cropTraitObservation.inventory(), aliasInventory)
			.innerJoin(aliasInventory.accession(), aliasAccession)
			.select(
				aliasInventory.id,
				aliasAccession.accessionNumberPart1,
				aliasAccession.accessionNumberPart2,
				aliasInventory.inventoryNumberPart1,
				aliasInventory.inventoryNumberPart2
			).distinct()
			.where(QCropTraitObservation.cropTraitObservation.cropTrait().id.in(cropTraitIds)) // WARNING Limit of 1000!
			.orderBy(
				aliasAccession.accessionNumberPart1.asc(),
				aliasAccession.accessionNumberPart2.asc(),
				aliasInventory.inventoryNumberPart1.asc(),
				aliasInventory.inventoryNumberPart2.asc(),
				aliasInventory.id.asc()
			);

		if (hasMethodFilter(filter)) { // Apply filtering by method
			invQ.where(QCropTraitObservation.cropTraitObservation.method().id.in(filter.getMethodIds()));
		}
		if (filter != null && filter.getInventory() != null) {
			invQ.where(ExpressionUtils.allOf(filter.getInventory().collectPredicates(aliasInventory)));
		}

		// This is a sorted list of all inventory ids matching the inventory/method/cropTrait filters.
		// The stream is closed explicitly because the persistence provider may keep a cursor open until then.
		try (var rows = invQ.stream()) {
			return rows.map(t -> t.get(0, Long.class)).collect(Collectors.toList());
		}
	}

}