CropTraitObservationDataServiceImpl.java

/*
 * Copyright 2021 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.gringlobal.service.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.gringlobal.api.exception.InvalidApiUsageException;
import org.gringlobal.api.v1.impl.CropTraitObservationController;
import org.gringlobal.model.CropTrait;
import org.gringlobal.model.CropTraitObservation;
import org.gringlobal.model.CropTraitObservationData;
import org.gringlobal.model.Inventory;
import org.gringlobal.model.Method;
import org.gringlobal.model.QCropTrait;
import org.gringlobal.model.QCropTraitObservationData;
import org.gringlobal.model.QInventory;
import org.gringlobal.model.community.CommunityCodeValues;
import org.gringlobal.persistence.CropTraitObservationDataRepository;
import org.gringlobal.persistence.CropTraitRepository;
import org.gringlobal.persistence.InventoryRepository;
import org.gringlobal.persistence.MethodRepository;
import org.gringlobal.service.CropTraitObservationDataService;
import org.gringlobal.service.CropTraitService;
import org.gringlobal.service.CropTraitTranslationService.TranslatedCropTrait;
import org.gringlobal.service.filter.CropTraitObservationDataFilter;
import org.hibernate.Hibernate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.util.Pair;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.validation.annotation.Validated;

import com.querydsl.core.BooleanBuilder;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQuery;


/**
 * Service for raw crop trait observation data (CTOD) records.
 *
 * <p>Besides CRUD and filtered search, it generates non-persisted {@link CropTraitObservation}
 * (CTO) summaries from CTODs and bulk-creates CTOD records for method/trait/inventory
 * combinations that do not have any yet.
 */
@Service
@Validated
@Transactional(readOnly = true)
@Slf4j
public class CropTraitObservationDataServiceImpl extends FilteredCRUDServiceImpl<CropTraitObservationData, CropTraitObservationDataFilter, CropTraitObservationDataRepository> implements CropTraitObservationDataService {

	@Autowired
	private CropTraitService cropTraitService;

	@Autowired
	private MethodRepository methodRepository;

	@Autowired
	private InventoryRepository inventoryRepository;

	@Autowired
	private CropTraitRepository cropTraitRepository;

	/**
	 * Base listing query: fetch-joins method, inventory and trait (inner joins) and the optional
	 * trait code (left join) so listed CTODs come with these associations loaded.
	 */
	@Override
	protected JPAQuery<CropTraitObservationData> entityListQuery() {
		return jpaQueryFactory.selectFrom(QCropTraitObservationData.cropTraitObservationData)
				// method
				.join(QCropTraitObservationData.cropTraitObservationData.method()).fetchJoin()
				// inventory
				.join(QCropTraitObservationData.cropTraitObservationData.inventory()).fetchJoin()
				// trait
				.join(QCropTraitObservationData.cropTraitObservationData.cropTrait()).fetchJoin()
				// trait code (optional)
				.leftJoin(QCropTraitObservationData.cropTraitObservationData.cropTraitCode()).fetchJoin()
			;
	}

	/**
	 * Create a new CTOD by applying {@code source} onto a fresh entity.
	 *
	 * @param source data to copy; must not have an id
	 * @return the saved entity, passed through {@code _lazyLoad}
	 */
	@Override
	@Transactional
	public CropTraitObservationData create(CropTraitObservationData source) {
		assert(source.getId() == null);
		CropTraitObservationData observationData = new CropTraitObservationData();
		observationData.apply(source);

		var saved = repository.save(observationData);
		return _lazyLoad(saved);
	}

	/**
	 * Apply {@code input} onto the existing {@code target} and save it.
	 *
	 * @param input new values; must carry an id
	 * @param target the persisted entity to update
	 * @return the saved entity, passed through {@code _lazyLoad}
	 */
	@Override
	@Transactional
	public CropTraitObservationData update(CropTraitObservationData input, CropTraitObservationData target) {
		assert(input.getId() != null);
		target.apply(input);

		var saved = repository.save(target);
		return _lazyLoad(saved);
	}

	@Override
	@Transactional
	public CropTraitObservationData remove(CropTraitObservationData entity) {
		return super.remove(entity);
	}

	/**
	 * Generate CTO records from the raw CTODs recorded for {@code method} and {@code cropTrait}.
	 * These records are not persisted!
	 *
	 * <p>CTODs are grouped by method + inventory + trait. Coded traits ({@code isCoded = "Y"}) are
	 * further split by trait code, and character traits by string value. Each group becomes one
	 * CTO. When an inventory ends up with more than one CTO, frequency and rank are set on them.
	 *
	 * @param method the method the data was recorded with
	 * @param cropTrait the observed trait
	 * @return the generated, non-persisted observations
	 */
	@Override
	public List<CropTraitObservation> generateObservations(Method method, CropTrait cropTrait) {

		var qCtod = QCropTraitObservationData.cropTraitObservationData;
		// findAll(Predicate) is declared to return an Iterable: copy it rather than casting to Collection
		var ctods = new ArrayList<CropTraitObservationData>();
		repository.findAll(qCtod.method().eq(method).and(qCtod.cropTrait().eq(cropTrait))).forEach(ctods::add);

		log.debug("Found corresponding {} CTODs", ctods.size());

		// Generate a map by method+inventory+trait(+traitCode) key containing a List<CTOD>
		var ctodByKey = ctods.stream().collect(Collectors.groupingBy(CropTraitObservationDataServiceImpl::observationKey));

		log.info("Generated {} CTO keys for {} CTODs", ctodByKey.size(), ctods.size());

		var generatedCtos = new ArrayList<CropTraitObservation>(ctodByKey.size());
		for (var entry : ctodByKey.entrySet()) {
			var rawCtods = entry.getValue(); // These all tally up by method+inventory+trait(+traitCode)
			log.debug("Processing {} with {} CTOD", entry.getKey(), rawCtods.size());
			generatedCtos.add(makeCropTraitObservation(rawCtods));
		}

		updateFrequencyAndRank(generatedCtos);

		return generatedCtos;
	}

	/**
	 * Build the grouping key of a CTOD: method id, inventory id and trait id, followed by the trait
	 * code id for coded traits or by the string value for character traits.
	 * A missing trait code or string value contributes {@code "null"} instead of failing.
	 */
	private static String observationKey(CropTraitObservationData ctod) {
		var trait = ctod.getCropTrait();
		var baseKey = ctod.getMethod().getId() + "-" + ctod.getInventory().getId() + "-" + trait.getId();

		if (StringUtils.equals("Y", trait.getIsCoded())) {
			// CTODs created without a code (e.g. by ensureObservationData) must not cause an NPE
			var traitCode = ctod.getCropTraitCode();
			return baseKey + "-" + (traitCode == null ? null : traitCode.getId());
		}
		if (StringUtils.equals(CommunityCodeValues.CROP_TRAIT_DATA_TYPE_CHAR.value, trait.getDataTypeCode())) {
			return baseKey + "-" + ctod.getStringValue();
		}
		return baseKey;
	}

	/**
	 * For every inventory with more than one CTO, set each CTO's frequency
	 * (its sample size divided by the inventory's total sample size) and its 0-based rank
	 * (0 = largest sample size). Sorts each per-inventory group in place.
	 */
	private static void updateFrequencyAndRank(List<CropTraitObservation> ctos) {
		var ctosByInventory = ctos.stream().collect(Collectors.groupingBy(CropTraitObservation::getInventory));
		for (var entry : ctosByInventory.entrySet()) {
			var iCtos = entry.getValue();
			if (iCtos.size() < 2) {
				continue;
			}
			log.debug("Updating frequency and rank for inventory={} size={}", entry.getKey().getId(), iCtos.size());

			var sumSize = iCtos.stream().mapToInt(CropTraitObservation::getSampleSize).sum();
			iCtos.forEach(o -> o.setFrequency(o.getSampleSize().doubleValue() / sumSize));

			// Largest sample first. Rank is the position in this order; using the index directly
			// avoids an O(n^2) indexOf that relies on equals() of unpersisted entities.
			iCtos.sort((a, b) -> Integer.compare(b.getSampleSize(), a.getSampleSize()));
			for (int rank = 0; rank < iCtos.size(); rank++) {
				iCtos.get(rank).setRank(rank);
			}
		}
	}

	/**
	 * Build one CTO from a group of CTODs that share the same grouping key.
	 * Keys (method, inventory, trait, trait code) come from the first CTOD; sample size is the
	 * number of CTODs. Numeric traits get summary statistics of the non-null numeric values.
	 *
	 * @param ctods non-empty group of CTODs
	 */
	private CropTraitObservation makeCropTraitObservation(List<CropTraitObservationData> ctods) {
		assert(ctods.size() > 0); // Must always have CTODs
		var cto = new CropTraitObservation();

		cto.setObservationData(ctods);

		// Populate keys
		var ctod1 = ctods.get(0);
		cto.setMethod(ctod1.getMethod());
		cto.setInventory(ctod1.getInventory());
		cto.setCropTrait(ctod1.getCropTrait());
		cto.setCropTraitCode(ctod1.getCropTraitCode());

		cto.setSampleSize(ctods.size()); // We have data on individuals, so sample size is the count

		if (cto.getCropTraitCode() != null) {
			// Coded observation: the trait code copied above identifies the value; nothing else to summarize

		} else if (StringUtils.equals(CommunityCodeValues.CROP_TRAIT_DATA_TYPE_CHAR.value, cto.getCropTrait().getDataTypeCode())) {
			// Text observations are grouped by string value, so the first CTOD's value represents the group
			cto.setStringValue(ctod1.getStringValue());

		} else if (StringUtils.equals(CommunityCodeValues.CROP_TRAIT_DATA_TYPE_NUMERIC.value, cto.getCropTrait().getDataTypeCode())) {
			// Numeric trait observations are summarized in CTO
			var numericValues = ctods.stream().map(CropTraitObservationData::getNumericValue).filter(v -> v != null).collect(Collectors.toList());
			if (numericValues.size() > 1) {
				var stats = numericValues.stream().collect(Collectors.summarizingDouble(Double::doubleValue));
				cto.setNumericValue(stats.getAverage());
				cto.setMaximumValue(stats.getMax());
				cto.setMeanValue(stats.getAverage());
				cto.setMinimumValue(stats.getMin());

				// Sample standard deviation (n - 1 denominator)
				var squaredDeviations = numericValues.stream().collect(Collectors.summingDouble(v -> Math.pow(v - stats.getAverage(), 2)));
				cto.setStandardDeviation(Math.sqrt(squaredDeviations / (stats.getCount() - 1)));

			} else if (numericValues.size() > 0) {
				cto.setNumericValue(numericValues.get(0));
			}
		}

		return cto;
	}

	/**
	 * Search CTODs matching {@code filter} and group the page of results by inventory.
	 *
	 * @param filter search filter; its {@code cropTraitId} list is required
	 * @param page paging of the CTOD query
	 * @return the filter, the translated requested traits, the methods seen in the page and
	 *         the page's CTODs grouped per inventory
	 * @throws InvalidApiUsageException when no trait ids are given
	 */
	@Override
	public FilteredObservationsData search(CropTraitObservationDataFilter filter, Pageable page) {
		if (filter == null || CollectionUtils.isEmpty(filter.cropTraitId)) {
			throw new InvalidApiUsageException("List of traits is required");
		}

		BooleanBuilder predicate = new BooleanBuilder();
		predicate.and(filter.buildPredicate());

		Page<CropTraitObservationData> resultPage = repository.findAll(predicate, page);

		Set<Method> methods = new HashSet<>();
		Set<TranslatedCropTrait> cropTraits = new HashSet<>();
		filter.cropTraitId.forEach((cropTraitId) -> cropTraits.add(cropTraitService.loadTranslated(cropTraitId)));

		Map<Inventory, Set<CropTraitObservationData>> observationsData = new HashMap<>();

		resultPage.forEach(o -> {
			// initialize lazy data
			Hibernate.initialize(o.getMethod());
			Hibernate.initialize(o.getInventory());

			methods.add(o.getMethod());
			observationsData.computeIfAbsent(o.getInventory(), inventory -> new HashSet<>()).add(o);
		});

		FilteredObservationsData result = new FilteredObservationsData();
		result.filter = filter;
		result.methods = methods;
		result.cropTraits = cropTraits;
		result.observationsData = new HashSet<>(observationsData.size());

		observationsData.forEach((inventory, data) -> result.observationsData.add(new InventoryWithObservationsData(inventory, data)));

		return result;
	}

	/**
	 * Page through the inventories that have observation data recorded with the given method.
	 * Uses a subquery rather than materializing all inventory ids, so large methods do not
	 * produce an oversized IN parameter list.
	 */
	@Override
	public Page<Inventory> getObservationDataInventoriesByMethod(Long methodId, Pageable pageable) {
		var method = methodRepository.getReferenceById(methodId);
		var qCtod = QCropTraitObservationData.cropTraitObservationData;
		var inventoryIds = JPAExpressions.select(qCtod.inventory().id).from(qCtod)
			.where(qCtod.method().id.eq(method.getId()).and(qCtod.inventory().isNotNull()));
		return inventoryRepository.findAll(QInventory.inventory.id.in(inventoryIds), pageable);
	}

	/**
	 * Page through the crop traits that have observation data recorded with the given method.
	 * Uses a subquery rather than materializing all trait ids, so large methods do not
	 * produce an oversized IN parameter list.
	 */
	@Override
	public Page<CropTrait> getObservationDataTraitsByMethod(Long methodId, Pageable pageable) {
		var method = methodRepository.getReferenceById(methodId);
		var qCtod = QCropTraitObservationData.cropTraitObservationData;
		var cropTraitIds = JPAExpressions.select(qCtod.cropTrait().id).from(qCtod)
			.where(qCtod.method().id.eq(method.getId()).and(qCtod.cropTrait().isNotNull()));
		return cropTraitRepository.findAll(QCropTrait.cropTrait.id.in(cropTraitIds), pageable);
	}

	/**
	 * Create a CTOD (method, trait and inventory set; no values) for every requested
	 * trait x inventory combination that has no CTOD for the method yet.
	 * Ids repeated in the request do not produce duplicate records.
	 *
	 * @param request method id, trait ids and inventory ids; all required
	 * @return number of CTOD records created
	 * @throws InvalidApiUsageException when any of the ids is missing
	 */
	@Override
	@Transactional
	public int ensureObservationData(CropTraitObservationController.EnsureObservationsRequest request) {
		if (request.methodId == null || CollectionUtils.isEmpty(request.inventoryId) || CollectionUtils.isEmpty(request.cropTraitId)) {
			throw new InvalidApiUsageException("Method id, CropTrait ids and inventory ids must be provided");
		}
		var method = methodRepository.getReferenceById(request.methodId);
		var qCtod = QCropTraitObservationData.cropTraitObservationData;
		var existing = jpaQueryFactory.selectFrom(qCtod)
			.where(qCtod.method().id.eq(method.getId())
				.and(qCtod.cropTrait().id.in(request.cropTraitId))
				.and(qCtod.inventory().id.in(request.inventoryId)))
			.fetch();

		// (trait id, inventory id) pairs that already have data; a HashSet keeps lookups O(1)
		Set<Pair<?, ?>> covered = new HashSet<>();
		existing.forEach(obs -> covered.add(Pair.of(obs.getCropTrait().getId(), obs.getInventory().getId())));

		List<CropTraitObservationData> observationDataForSave = new ArrayList<>();
		for (var traitId : request.cropTraitId) {
			for (var inventoryId : request.inventoryId) {
				// add() is false for pairs already covered, including ids repeated in the request
				if (covered.add(Pair.of(traitId, inventoryId))) {
					CropTraitObservationData observationData = new CropTraitObservationData();
					observationData.setCropTrait(cropTraitRepository.getReferenceById(traitId));
					observationData.setInventory(inventoryRepository.getReferenceById(inventoryId));
					observationData.setMethod(method);
					observationDataForSave.add(observationData);
				}
			}
		}

		var saved = repository.saveAll(observationDataForSave);
		return saved.size();
	}
}