CropTraitObservationServiceImpl.java

/*
 * Copyright 2021 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.gringlobal.service.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.gringlobal.api.exception.InvalidApiUsageException;
import org.gringlobal.api.v1.MultiOp;
import org.gringlobal.api.v1.impl.CropTraitObservationController;
import org.gringlobal.custom.elasticsearch.SearchException;
import org.gringlobal.model.CropTrait;
import org.gringlobal.model.CropTraitCode;
import org.gringlobal.model.CropTraitObservation;
import org.gringlobal.model.Inventory;
import org.gringlobal.model.Method;
import org.gringlobal.model.QCropTrait;
import org.gringlobal.model.QCropTraitObservation;
import org.gringlobal.model.QInventory;
import org.gringlobal.persistence.CropTraitObservationRepository;
import org.gringlobal.persistence.CropTraitRepository;
import org.gringlobal.persistence.InventoryRepository;
import org.gringlobal.persistence.MethodRepository;
import org.gringlobal.service.CropTraitCodeService;
import org.gringlobal.service.CropTraitObservationService;
import org.gringlobal.service.CropTraitService;
import org.gringlobal.service.CropTraitCodeTranslationService.TranslatedCropTraitCode;
import org.gringlobal.service.CropTraitTranslationService.TranslatedCropTrait;
import org.gringlobal.service.filter.CropTraitCodeFilter;
import org.gringlobal.service.filter.CropTraitFilter;
import org.gringlobal.service.filter.CropTraitObservationFilter;
import org.hibernate.Hibernate;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.util.Pair;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.validation.annotation.Validated;

import com.querydsl.core.BooleanBuilder;
import com.querydsl.jpa.JPAExpressions;
import com.querydsl.jpa.impl.JPAQuery;


/**
 * Default {@link CropTraitObservationService} implementation.
 *
 * <p>Creates, updates and removes {@link CropTraitObservation} records (each write operation is
 * guarded by {@code @ggceSec.actionAllowed('CropTraitObservation', ...)}). It verifies that an
 * observation's {@link CropTraitCode}, when present, belongs to the observation's {@link CropTrait}.
 * It also offers lookups by method and "translated" views that pair observations with translated
 * trait and trait-code data.
 */
@Service
@Validated
@Transactional(readOnly = true)
@Slf4j
public class CropTraitObservationServiceImpl extends FilteredCRUDService2Impl<CropTraitObservation, CropTraitObservationFilter, CropTraitObservationRepository> implements CropTraitObservationService {

	@Autowired
	private CropTraitService cropTraitService;

	@Autowired
	private MethodRepository methodRepository;

	@Autowired
	private InventoryRepository inventoryRepository;

	@Autowired
	private CropTraitRepository cropTraitRepository;

	@Autowired
	private CropTraitCodeService cropTraitCodeService;

	/**
	 * Base query used for listing observations.
	 *
	 * <p>Method, inventory and crop trait are inner-joined and fetch-joined. The crop trait code is
	 * optional and is therefore left-joined.
	 */
	@Override
	protected JPAQuery<CropTraitObservation> entityListQuery() {
		var observation = QCropTraitObservation.cropTraitObservation;
		return jpaQueryFactory.selectFrom(observation)
				// method
				.join(observation.method()).fetchJoin()
				// inventory
				.join(observation.inventory()).fetchJoin()
				// trait
				.join(observation.cropTrait()).fetchJoin()
				// trait code (optional, hence left join)
				.leftJoin(observation.cropTraitCode()).fetchJoin();
	}

	/**
	 * Creates a new observation from {@code source}.
	 *
	 * <p>The crop trait is re-loaded with {@link CropTraitService#get}. A supplied crop trait code is
	 * re-loaded with {@link CropTraitCodeService#get} and must belong to that trait. The values are
	 * copied into a fresh entity before saving.
	 *
	 * @param source observation data; {@code cropTrait} must carry an id
	 * @return the saved observation, passed through {@code _lazyLoad}
	 * @throws InvalidApiUsageException if the crop trait code belongs to a different crop trait
	 */
	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'CREATE')")
	@Transactional
	public CropTraitObservation create(CropTraitObservation source) {
		log.debug("Create CropTraitObservation. Input data {}", source);
		source.setCropTrait(cropTraitService.get(source.getCropTrait().getId()));
		resolveCropTraitCode(source);

		CropTraitObservation observation = new CropTraitObservation();
		observation.apply(source);

		CropTraitObservation saved = repository.save(observation);
		return _lazyLoad(saved);
	}

	/**
	 * Same validation as {@link #create(CropTraitObservation)}, but persists through
	 * {@code super.createFast} instead of saving and lazy-loading here.
	 *
	 * @throws InvalidApiUsageException if the crop trait code belongs to a different crop trait
	 */
	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'CREATE')")
	@Transactional
	public CropTraitObservation createFast(CropTraitObservation source) {
		log.debug("Create CropTraitObservation. Input data {}", source);
		source.setCropTrait(cropTraitService.get(source.getCropTrait().getId()));
		resolveCropTraitCode(source);

		CropTraitObservation observation = new CropTraitObservation();
		observation.apply(source);

		return super.createFast(observation);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'CREATE')")
	@Transactional
	public MultiOp<CropTraitObservation> create(List<CropTraitObservation> inserts) {
		return super.create(inserts);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'CREATE')")
	@Transactional
	public MultiOp<CropTraitObservation> createFast(List<CropTraitObservation> inserts) {
		return super.createFast(inserts);
	}

	/**
	 * Applies {@code cto} to the stored observation with the same id.
	 *
	 * <p>The target's previous {@code modifiedDate} is kept. Associations on {@code cto} are used as
	 * given and are not re-loaded.
	 *
	 * @throws InvalidApiUsageException if the crop trait code belongs to a different crop trait
	 */
	@Override
	@Transactional
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'WRITE')")
	public CropTraitObservation forceUpdate(CropTraitObservation cto) {
		checkCodeBelongsToTrait(cto);

		var target = get(cto.getId());
		var lmd = target.getModifiedDate();
		target.apply(cto);
		// Restore the original modifiedDate rather than keeping whatever apply() left on the target.
		target.setModifiedDate(lmd);
		return repository.save(target);
	}

	/**
	 * Applies {@code updated} to {@code target} and saves it, without the {@code _lazyLoad} step.
	 *
	 * @throws InvalidApiUsageException if the crop trait code belongs to a different crop trait
	 */
	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'WRITE')")
	@Transactional
	public CropTraitObservation updateFast(CropTraitObservation updated, CropTraitObservation target) {
		updated.setCropTrait(cropTraitRepository.getReferenceById(updated.getCropTrait().getId()));
		resolveCropTraitCode(updated);
		target.apply(updated);
		return repository.save(target);
	}

	/**
	 * Applies {@code input} to {@code target}, saves it and returns it via {@code _lazyLoad}.
	 *
	 * @throws InvalidApiUsageException if the crop trait code belongs to a different crop trait
	 */
	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'WRITE')")
	@Transactional
	public CropTraitObservation update(CropTraitObservation input, CropTraitObservation target) {
		log.debug("Update CropTraitObservation. Input data {}", input);
		input.setCropTrait(cropTraitRepository.getReferenceById(input.getCropTrait().getId()));
		resolveCropTraitCode(input);
		target.apply(input);

		CropTraitObservation saved = repository.save(target);
		return _lazyLoad(saved);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'WRITE')")
	@Transactional
	public MultiOp<CropTraitObservation> update(List<CropTraitObservation> updates) {
		return super.update(updates);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'WRITE')")
	@Transactional
	public MultiOp<CropTraitObservation> updateFast(List<CropTraitObservation> updates) {
		return super.updateFast(updates);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'DELETE')")
	@Transactional
	public MultiOp<CropTraitObservation> remove(List<CropTraitObservation> deletes) {
		return super.remove(deletes);
	}

	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'DELETE')")
	@Transactional
	public CropTraitObservation remove(CropTraitObservation entity) {
		return super.remove(entity);
	}

	/**
	 * Lists observations matching {@code filter} and calls {@code lazyLoad()} on each one on the page.
	 */
	@Override
	public Page<CropTraitObservation> list(CropTraitObservationFilter filter, Pageable page) throws SearchException {
		var r = super.list(filter, page);
		r.getContent().forEach(cto -> cto.lazyLoad());
		return r;
	}

	/**
	 * Finds observations matching {@code filter} and groups them by inventory.
	 *
	 * <p>The result also carries the set of methods seen on the page and the translated crop traits
	 * for every id in {@code filter.cropTraitId}.
	 *
	 * @throws InvalidApiUsageException if {@code filter} is null or lists no crop trait ids
	 */
	@Override
	public FilteredObservations search(CropTraitObservationFilter filter, Pageable page) {
		if (filter == null || CollectionUtils.isEmpty(filter.cropTraitId)) {
			throw new InvalidApiUsageException("List of traits is required");
		}

		BooleanBuilder predicate = new BooleanBuilder();
		predicate.and(filter.buildPredicate());

		Page<CropTraitObservation> resultPage = repository.findAll(predicate, page);

		Set<Method> methods = new HashSet<>();
		Set<TranslatedCropTrait> cropTraits = new HashSet<>();
		filter.cropTraitId.forEach(cropTraitId -> cropTraits.add(cropTraitService.loadTranslated(cropTraitId)));

		Map<Inventory, Set<CropTraitObservation>> observations = new HashMap<>();
		resultPage.forEach(o -> {
			// Initialize lazy associations before they are handed back to the caller.
			Hibernate.initialize(o.getMethod());
			Hibernate.initialize(o.getInventory());

			methods.add(o.getMethod());
			observations.computeIfAbsent(o.getInventory(), inventory -> new HashSet<>()).add(o);
		});

		FilteredObservations result = new FilteredObservations();
		result.filter = filter;
		result.methods = methods;
		result.cropTraits = cropTraits;
		result.observations = new HashSet<>(observations.size());
		observations.forEach((inventory, inventoryObservations) ->
			result.observations.add(new InventoryWithObservations(inventory, inventoryObservations)));

		return result;
	}

	/**
	 * Lists the inventories that have at least one observation recorded with the given method.
	 *
	 * <p>Matching inventory ids are selected in a subquery. Fetching them first and passing them back
	 * as a literal {@code IN} list would bind one parameter per id, which can exceed database
	 * parameter limits for methods with many observations.
	 */
	@Override
	public Page<Inventory> getObservationInventoriesByMethod(Long methodId, Pageable pageable) {
		var observation = QCropTraitObservation.cropTraitObservation;
		var inventoryIds = JPAExpressions.select(observation.inventory().id).from(observation)
			.where(observation.method().id.eq(methodId).and(observation.inventory().isNotNull()));
		return inventoryRepository.findAll(QInventory.inventory.id.in(inventoryIds), pageable);
	}

	/**
	 * Lists the crop traits that have at least one observation recorded with the given method.
	 *
	 * <p>Uses a subquery for the same reason as {@link #getObservationInventoriesByMethod(Long, Pageable)}.
	 */
	@Override
	public Page<CropTrait> getObservationTraitsByMethod(Long methodId, Pageable pageable) {
		var observation = QCropTraitObservation.cropTraitObservation;
		var cropTraitIds = JPAExpressions.select(observation.cropTrait().id).from(observation)
			.where(observation.method().id.eq(methodId).and(observation.cropTrait().isNotNull()));
		return cropTraitRepository.findAll(QCropTrait.cropTrait.id.in(cropTraitIds), pageable);
	}

	/**
	 * Makes sure an observation exists for every (crop trait, inventory) combination in the request
	 * under the requested method, creating only the missing ones.
	 *
	 * @return number of observations created
	 * @throws InvalidApiUsageException if the method id, crop trait ids or inventory ids are missing
	 */
	@Override
	@PreAuthorize("@ggceSec.actionAllowed('CropTraitObservation', 'CREATE')")
	@Transactional
	public int ensureObservations(CropTraitObservationController.EnsureObservationsRequest request) {
		if (request.methodId == null || CollectionUtils.isEmpty(request.inventoryId) || CollectionUtils.isEmpty(request.cropTraitId)) {
			throw new InvalidApiUsageException("Method id, CropTrait ids and inventory ids must be provided");
		}
		var method = methodRepository.getReferenceById(request.methodId);
		var observation = QCropTraitObservation.cropTraitObservation;
		var traitIdPath = observation.cropTrait().id;
		var inventoryIdPath = observation.inventory().id;

		// (cropTraitId, inventoryId) pairs that already have an observation for this method.
		// Only the two ids are selected; the observation entities themselves are not needed.
		Set<Pair<Long, Long>> coveredPairs = jpaQueryFactory.select(traitIdPath, inventoryIdPath).from(observation)
			.where(observation.method().id.eq(request.methodId)
				.and(traitIdPath.in(request.cropTraitId))
				.and(inventoryIdPath.in(request.inventoryId)))
			.fetch().stream()
			.map(row -> Pair.of(row.get(traitIdPath), row.get(inventoryIdPath)))
			.collect(Collectors.toCollection(HashSet::new));

		List<CropTraitObservation> observationsForSave = new ArrayList<>();
		for (var traitId : request.cropTraitId) {
			for (var inventoryId : request.inventoryId) {
				// add() returns false for pairs that already exist, including ids repeated in the request.
				if (coveredPairs.add(Pair.of(traitId, inventoryId))) {
					CropTraitObservation created = new CropTraitObservation();
					created.setCropTrait(cropTraitRepository.getReferenceById(traitId));
					created.setInventory(inventoryRepository.getReferenceById(inventoryId));
					created.setMethod(method);
					observationsForSave.add(created);
				}
			}
		}

		return repository.saveAll(observationsForSave).size();
	}

	/**
	 * Lists observations matching {@code filter} and pairs each one with its translated crop trait
	 * and, when set, its translated crop trait code.
	 *
	 * <p>Translations are fetched in one query per type, limited to the ids referenced on the page.
	 * The lookup is skipped when no ids were collected, for two reasons. For an unpaged request with
	 * no results, {@code getSize()} is 0 and {@code PageRequest.of(0, 0)} would throw. An empty id
	 * filter may also fail to restrict the query at all.
	 */
	@Override
	@Transactional(readOnly = true)
	public Page<TranslatedCropTraitObservation> listTranslated(CropTraitObservationFilter filter, Pageable page) throws SearchException {
		Page<CropTraitObservation> loadedPage = super.list(filter, page);

		// translated crop traits referenced by observations on this page
		CropTraitFilter cropTraitFilter = new CropTraitFilter();
		cropTraitFilter.id = loadedPage.stream()
			.map(CropTraitObservation::getCropTrait)
			.map(CropTrait::getId)
			.collect(Collectors.toSet());

		List<TranslatedCropTrait> translatedCropTraits = cropTraitFilter.id.isEmpty()
			? List.of()
			: cropTraitService.listFiltered(cropTraitFilter, PageRequest.of(0, loadedPage.getSize())).getContent();

		// translated crop trait codes referenced by observations on this page (codes are optional)
		CropTraitCodeFilter cropTraitCodeFilter = new CropTraitCodeFilter();
		cropTraitCodeFilter.id = loadedPage.stream()
			.map(CropTraitObservation::getCropTraitCode)
			.filter(Objects::nonNull)
			.map(CropTraitCode::getId)
			.collect(Collectors.toSet());

		List<TranslatedCropTraitCode> translatedCropTraitCodes = cropTraitCodeFilter.id.isEmpty()
			? List.of()
			: cropTraitCodeService.listFiltered(cropTraitCodeFilter, PageRequest.of(0, loadedPage.getSize())).getContent();

		return loadedPage.map(observation -> convert(observation, translatedCropTraits, translatedCropTraitCodes));
	}

	/**
	 * Wraps {@code observation} together with the first matching translated trait and, if the
	 * observation has a code, the first matching translated code. A component stays {@code null}
	 * when no match is found.
	 */
	private TranslatedCropTraitObservation convert(CropTraitObservation observation, List<TranslatedCropTrait> translatedDescriptors, List<TranslatedCropTraitCode> translatedCodes) {
		TranslatedCropTraitObservation translatedObservation = new TranslatedCropTraitObservation();
		translatedObservation.observation = observation;

		translatedObservation.translatedCropTrait = translatedDescriptors.stream()
			.filter(translatedCropTrait -> Objects.equals(translatedCropTrait.entity.getId(), observation.getCropTrait().getId()))
			.findFirst().orElse(null);

		if (Objects.nonNull(observation.getCropTraitCode())) {
			translatedObservation.translatedCropTraitCode = translatedCodes.stream()
				.filter(translatedCode -> Objects.equals(translatedCode.entity.getId(), observation.getCropTraitCode().getId()))
				.findFirst().orElse(null);
		}

		return translatedObservation;
	}

	/**
	 * Loads one observation by id and attaches its translated crop trait and crop trait code,
	 * each looked up only when the observation references one.
	 */
	@Override
	public TranslatedCropTraitObservation getTranslated(long id) {
		var loadedObservation = super.load(id);

		var translated = new TranslatedCropTraitObservation();
		translated.observation = loadedObservation;

		var loadedCropTrait = loadedObservation.getCropTrait();
		if (loadedCropTrait != null) {
			translated.translatedCropTrait = cropTraitService.loadTranslated(loadedCropTrait.getId());
		}

		var loadedCropTraitCode = loadedObservation.getCropTraitCode();
		if (loadedCropTraitCode != null) {
			translated.translatedCropTraitCode = cropTraitCodeService.loadTranslated(loadedCropTraitCode.getId());
		}

		return translated;
	}

	/**
	 * Replaces a supplied crop trait code with the instance returned by
	 * {@link CropTraitCodeService#get} and checks that it belongs to the observation's crop trait.
	 * Does nothing when the observation has no code.
	 *
	 * @param observation observation whose {@code cropTrait} has already been resolved
	 * @throws InvalidApiUsageException if the code belongs to a different crop trait
	 */
	private void resolveCropTraitCode(CropTraitObservation observation) {
		if (observation.getCropTraitCode() == null) {
			return;
		}
		observation.setCropTraitCode(cropTraitCodeService.get(observation.getCropTraitCode().getId()));
		checkCodeBelongsToTrait(observation);
	}

	/**
	 * Throws if the observation has a crop trait code whose crop trait id differs from the
	 * observation's own crop trait id. Does nothing when the observation has no code.
	 *
	 * @throws InvalidApiUsageException if the code belongs to a different crop trait
	 */
	private static void checkCodeBelongsToTrait(CropTraitObservation observation) {
		var code = observation.getCropTraitCode();
		if (code != null && !observation.getCropTrait().getId().equals(code.getCropTrait().getId())) {
			throw new InvalidApiUsageException("cropTraitCode does not belong to cropTrait");
		}
	}

}