DuplicateFinder.java

/*
 * Copyright 2021 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.gringlobal.worker.dupe;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import com.fasterxml.jackson.annotation.JsonUnwrapped;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.genesys.blocks.model.EntityId;
import org.genesys.taxonomy.checker.StringSimilarity;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.transaction.annotation.Transactional;


/**
 * Duplicate Finder base.
 *
 * <p>
 * Subclasses supply the candidate entities ({@link #getCandidates}), the scoring function
 * ({@link #scoreHit}) and the score that counts as a best match ({@link #getBestScoreThreshold}).
 * This class removes the target itself and explicitly excluded IDs from the candidates,
 * de-duplicates them by ID, scores them, keeps the 20 best and assigns each a {@link HitRating}.
 *
 * <p>
 * It also provides string comparison helpers intended for use by {@code scoreHit} implementations.
 *
 * @param <T> the entity type
 */
@Slf4j
public abstract class DuplicateFinder<T extends EntityId> {

	/** Maximum number of hits returned by {@code findSimilar}. */
	private static final int MAX_HITS = 20;

	/** Weight of a letter-only token in {@link #compareStringsAndNumbers(Set, Set)}. */
	private static final int STRING_WEIGHT = 18;

	/** Weight of a number token; numbers are more important than words. */
	private static final int NUMBER_WEIGHT = 20;

	/** Decimal form of {@link Long#MAX_VALUE}, used to decide whether a digit run fits into a long. */
	private static final String MAX_LONG_DIGITS = Long.toString(Long.MAX_VALUE);

	/** Runs of characters that are neither letters nor (ASCII) digits. */
	private static final Pattern NON_LETTERS_OR_DIGITS = Pattern.compile("[^\\p{L}\\d]+");

	/**
	 * Coarse quality bucket of a hit. Serialized to JSON as its integer rating (4 = best, 1 = poor).
	 */
	public enum HitRating {
		BEST(4), GOOD(3), OK(2), POOR(1);

		private final int rating;

		HitRating(int rating) {
			this.rating = rating;
		}

		@JsonValue
		public int getRating() {
			return rating;
		}
	}

	/**
	 * The hits found for one target entity.
	 *
	 * @param <T> the entity type
	 */
	public static class SimilarityHit<T> {
		public T source;
		public List<Hit<T>> results;

		public SimilarityHit(T source, List<Hit<T>> results) {
			this.source = source;
			this.results = results;
		}
	}

	/**
	 * A candidate entity together with its score, rating and the text fragments that matched.
	 *
	 * @param <T> the entity type
	 */
	public static class Hit<T> {
		/** The candidate; its properties are unwrapped into this object's JSON. */
		@JsonUnwrapped
		public T result;
		/** Assigned by {@code findSimilar} from {@code score / getBestScoreThreshold()}. */
		public HitRating hitRating;
		/** Similarity score; {@code findSimilar} sorts hits by this field, highest first. */
		public double score = 0;
		/** Fragments that matched, collected by the comparison helpers. */
		public List<String> matches = new ArrayList<>();

		public Hit(T result, Double score) {
			this.result = result;
			this.score = score != null ? score.doubleValue() : 0;
		}
	}

	/**
	 * Find similar entities for each of the targets.
	 *
	 * <p>
	 * Deliberately not {@code final}. A class-based (CGLIB) Spring proxy cannot intercept final
	 * methods, so {@code @Transactional} and the {@code @PreAuthorize} check would be silently skipped.
	 *
	 * @param targets the entities to look up, must not be {@code null}
	 * @return one {@link SimilarityHit} per target, in the order of {@code targets}
	 */
	@Transactional(readOnly = true)
	@PreAuthorize("hasAuthority('GROUP_ADMINS')")
	public List<SimilarityHit<T>> findSimilar(List<T> targets) {
		assert (targets != null);

		log.info("Finding duplicates for {} targets", targets.size());
		return targets.stream().map((target) -> new SimilarityHit<T>(target, findSimilar(target))).collect(Collectors.toList());
	}

	/**
	 * Find entities similar to the target.
	 *
	 * <p>
	 * Not {@code final}, so that a CGLIB proxy can apply {@code @Transactional}.
	 *
	 * @param target the target
	 * @return up to 20 rated hits, best first
	 */
	@Transactional(readOnly = true)
	public List<Hit<T>> findSimilar(T target) {
		return findSimilar(target, CollectionUtils.emptyCollection());
	}

	/**
	 * Find entities similar to the target, excluding the ones with IDs listed in {@code excludedById}.
	 *
	 * <p>
	 * Processing steps:
	 * <ol>
	 * <li>Candidates come from {@link #getCandidates}.</li>
	 * <li>The target itself and excluded IDs are removed.</li>
	 * <li>Only the first candidate per ID is kept.</li>
	 * <li>Each remaining candidate is scored with {@link #scoreHit}.</li>
	 * <li>The 20 highest {@code hit.score} values are kept.</li>
	 * <li>Each kept hit is rated against {@link #getBestScoreThreshold}.</li>
	 * </ol>
	 *
	 * @param target the target, must not be {@code null}
	 * @param excludedById the list of candidate IDs to exclude from matching
	 * @return up to 20 rated hits, best first (mutable list)
	 */
	@Transactional(readOnly = true)
	public List<Hit<T>> findSimilar(T target, Collection<Long> excludedById) {
		assert (target != null);

		log.info("Searching for duplicates of {}", target);

		// Work on a private copy: the subclass may return an immutable or shared list
		var fetched = getCandidates(target, excludedById);
		List<T> candidates = fetched == null ? new ArrayList<>() : new ArrayList<>(fetched);

		// Never report the target as its own duplicate (candidate IDs may be null)
		var targetId = target.getId();
		if (targetId != null) {
			candidates.removeIf(candidate -> targetId.equals(candidate.getId()));
		}
		// Remove excluded candidates by id
		if (CollectionUtils.isNotEmpty(excludedById)) {
			candidates.removeIf(candidate -> excludedById.contains(candidate.getId()));
		}
		log.info("Found {} potential hits", candidates.size());

		// Keep the first candidate per ID and score it.
		// scoreHit must store its result in hit.score; the return value is not used here.
		Set<Long> seenIds = new HashSet<>();
		List<Hit<T>> scoredHits = new ArrayList<>();
		for (T candidate : candidates) {
			if (seenIds.add(candidate.getId())) {
				var hit = new Hit<T>(candidate, 0d);
				scoreHit(target, hit);
				scoredHits.add(hit);
			}
		}

		// Best first; List.sort is stable, so equal scores keep candidate order
		scoredHits.sort((a, b) -> Double.compare(b.score, a.score));
		List<Hit<T>> uniqueHits = new ArrayList<>(scoredHits.subList(0, Math.min(MAX_HITS, scoredHits.size())));

		// Rate each hit relative to the subclass-defined score of a best match
		var bestScoreThreshold = getBestScoreThreshold();
		for (var hit : uniqueHits) {
			hit.hitRating = rateHit(hit.score / bestScoreThreshold);
		}

		log.info("Found {} duplicates of {}", uniqueHits.size(), target);
		return uniqueHits;
	}

	/**
	 * Map a score, expressed as a fraction of the best score threshold, to a rating.
	 *
	 * <p>
	 * Above 0.9 is BEST, above 0.7 is GOOD and above 0.4 is OK. Anything else, including NaN, is POOR.
	 */
	private static HitRating rateHit(double fractionOfBest) {
		if (fractionOfBest > 0.9) {
			return HitRating.BEST;
		}
		if (fractionOfBest > 0.7) {
			return HitRating.GOOD;
		}
		if (fractionOfBest > 0.4) {
			return HitRating.OK;
		}
		return HitRating.POOR;
	}

	/**
	 * Gets the best score threshold: the score that corresponds to 100% when rating hits.
	 *
	 * @return the best score threshold
	 */
	protected abstract double getBestScoreThreshold();

	/**
	 * Find all candidates that are potential matches for target.
	 *
	 * <p>
	 * The returned list is copied before it is filtered, so it may be immutable.
	 * A {@code null} return is treated as "no candidates".
	 *
	 * @param target the target
	 * @param excludedById the IDs of excluded entities
	 * @return list of candidates
	 */
	protected abstract List<T> getCandidates(final T target, final Collection<Long> excludedById);

	/**
	 * Convert an ES query to a safe ES query.
	 *
	 * <p>
	 * Every run of characters that are neither letters nor digits is replaced with a single " ".
	 * The result is then trimmed, so leading and trailing punctuation leaves no stray spaces.
	 *
	 * @param rawEsQuery the raw ES search string, must not be {@code null}
	 * @return the safe search string
	 */
	protected final String toSafeEsQuery(String rawEsQuery) {
		assert (rawEsQuery != null);
		return NON_LETTERS_OR_DIGITS.matcher(rawEsQuery).replaceAll(" ").trim();
	}

	/**
	 * Score the target against the Hit.
	 *
	 * <p>
	 * Scoring should be transitive. NOTE(review): presumably this means order-independent, i.e.
	 * scoring a against b gives the same result as b against a. Confirm.
	 *
	 * <p>
	 * Implementations must store the score in {@code hit.score}: {@code findSimilar} sorts and rates
	 * by that field and does not use the returned value.
	 *
	 * @param target the target
	 * @param hit the potential match
	 * @return similarity score, the higher the better
	 */
	protected abstract double scoreHit(T target, Hit<T> hit);

	/**
	 * Case-insensitive equality of two non-empty strings.
	 *
	 * @param matches receives {@code b} when the strings are equal
	 * @param a first string
	 * @param b second string
	 * @return {@code true} if both are non-null, non-empty and equal ignoring case
	 */
	protected final boolean notNullEquals(final Collection<String> matches, final String a, final String b) {
		if (a == null || b == null || a.isEmpty() || b.isEmpty()) {
			return false;
		}
		if (StringUtils.equalsIgnoreCase(a, b)) {
			matches.add(b);
			return true;
		}
		return false;
	}

	/**
	 * Compute a string similarity score.
	 *
	 * <p>
	 * The score is the average of the Dice coefficient and the Levenshtein coefficient from
	 * {@code StringSimilarity}, computed on the lower-cased strings. Lower-casing uses
	 * {@link Locale#ROOT} so results do not depend on the JVM's default locale. The result is in
	 * [0, 1.0] (1.0 is a full match), assuming both coefficients are in that range.
	 *
	 * @param original the original
	 * @param candidate the candidate
	 * @return the score between 0 and 1.0 where 0 is no similarity and 1.0 is full match; 0 if
	 *         either string is null or empty
	 */
	protected final double similarityScore(String original, String candidate) {
		if (original == null || candidate == null || original.isEmpty() || candidate.isEmpty()) {
			return 0;
		}

		var lowerOriginal = original.toLowerCase(Locale.ROOT);
		var lowerCandidate = candidate.toLowerCase(Locale.ROOT);
		return (StringSimilarity.diceCoefficientOptimized(lowerOriginal, lowerCandidate)
				+ StringSimilarity.getLevenshteinCoefficient(lowerOriginal, lowerCandidate)) / 2.0;
	}

	/**
	 * Same as {@link #similarityScore(String, String)}, and records {@code candidate} in
	 * {@code matches} when the score is above 0.7.
	 */
	protected final double similarityScore(final Collection<String> matches, final String original, final String candidate) {
		var score = similarityScore(original, candidate);
		if (score > 0.7) {
			matches.add(candidate);
		}
		return score;
	}

	/**
	 * Same as {@link #stringsAndNumbersCompare(String, String)}, and records both strings in
	 * {@code matches} when the result is at least 0.5.
	 */
	protected final double stringsAndNumbersCompare(final Collection<String> matches, String a, String b) {
		var result = stringsAndNumbersCompare(a, b);
		if (result >= 0.5) {
			matches.add(a);
			matches.add(b);
		}
		return result;
	}

	/** Letter runs (group 1), or digit runs with leading zeros stripped (group 2). */
	private static final Pattern NUMBERS_AND_STRINGS = Pattern.compile("(\\p{L}+)|0*(\\d+)");

	/**
	 * Split input strings into sets consisting of parts of only digits and only letters (in lower
	 * case). Compare the two sets.
	 *
	 * @param a first string
	 * @param b second string
	 * @return a value between 0 and 1; 0 if either string is blank
	 */
	protected final double stringsAndNumbersCompare(String a, String b) {
		if (StringUtils.isBlank(a) || StringUtils.isBlank(b)) {
			return 0;
		}
		return compareStringsAndNumbers(uniqueStringsAndNumbers(a), uniqueStringsAndNumbers(b));
	}

	/**
	 * Compare text parts, strings separately from numbers.
	 *
	 * <p>
	 * Each token adds its weight to the total once for each set it appears in. Tokens present in
	 * both sets also add their weight to the shared sum, once per side. String tokens weigh 18 and
	 * all other (number) tokens weigh 20, so numbers count more than words. The result is
	 * shared / total.
	 *
	 * @param ma Set of Number | String
	 * @param mb Set of Number | String
	 * @return value in the range of 0 to 1: 1.0 for identical non-empty sets, 0.0 when nothing
	 *         matches or both sets are empty
	 */
	protected static final double compareStringsAndNumbers(Set<Object> ma, Set<Object> mb) {
		int total = 0;
		int shared = 0;

		for (var token : ma) {
			int weight = tokenWeight(token);
			total += weight;
			if (mb.contains(token)) {
				shared += weight;
			}
		}
		for (var token : mb) {
			int weight = tokenWeight(token);
			total += weight;
			if (ma.contains(token)) {
				shared += weight;
			}
		}

		return total == 0 ? 0.0 : (double) shared / total;
	}

	/** Weight of a token: strings are less important than numbers. */
	private static int tokenWeight(final Object token) {
		return token instanceof String ? STRING_WEIGHT : NUMBER_WEIGHT;
	}

	/**
	 * Convert a digit run captured by {@code NUMBERS_AND_STRINGS} (ASCII digits, leading zeros
	 * stripped) to a number.
	 *
	 * <p>
	 * Values that fit are returned as {@link Long}. Larger values are returned as {@link BigInteger}
	 * instead of failing with {@link NumberFormatException}.
	 */
	private static Object parseNumber(final String digits) {
		// Without leading zeros, comparing length and then lexical order equals numeric comparison
		boolean fitsInLong = digits.length() < MAX_LONG_DIGITS.length()
				|| (digits.length() == MAX_LONG_DIGITS.length() && digits.compareTo(MAX_LONG_DIGITS) <= 0);
		if (fitsInLong) {
			return Long.parseLong(digits);
		}
		return new BigInteger(digits);
	}

	private final Cache<String, Set<Object>> uniqueStringsAndNumbersCache = CacheBuilder.newBuilder()
			// size
			.maximumSize(100)
			// expiration
			.expireAfterWrite(10, TimeUnit.SECONDS).build();

	/**
	 * Split a string into its distinct tokens: lower-cased letter runs (String) and digit runs
	 * (Long, or BigInteger when too large).
	 *
	 * <p>
	 * Results are cached, so the returned set is unmodifiable to keep callers from corrupting the
	 * cache.
	 *
	 * @param a the input
	 * @return unmodifiable set of tokens; empty for blank input
	 */
	protected final Set<Object> uniqueStringsAndNumbers(final String a) {
		if (StringUtils.isBlank(a)) {
			return Set.of();
		}
		try {
			return uniqueStringsAndNumbersCache.get(a, () -> NUMBERS_AND_STRINGS.matcher(a).results()
					// type conversion
					.map((r) -> r.group(1) != null ? r.group(1).toLowerCase(Locale.ROOT) : parseNumber(r.group(2)))
					// get
					.collect(Collectors.toUnmodifiableSet()));
		} catch (ExecutionException e) {
			throw new RuntimeException("Something went wrong", e);
		}
	}

	private final Cache<String, List<Object>> stringsAndNumbersCache = CacheBuilder.newBuilder()
			// size
			.maximumSize(100)
			// expiration
			.expireAfterWrite(10, TimeUnit.SECONDS).build();

	/**
	 * Split a string into its tokens, in order: letter runs (String, case preserved) and digit runs
	 * (Long, or BigInteger when too large).
	 *
	 * <p>
	 * Results are cached, so the returned list is unmodifiable to keep callers from corrupting the
	 * cache.
	 *
	 * @param a the input
	 * @return unmodifiable list of tokens; empty for blank input
	 */
	protected final List<Object> toStringsAndNumbers(final String a) {
		if (StringUtils.isBlank(a)) {
			return List.of();
		}
		try {
			return stringsAndNumbersCache.get(a, () -> NUMBERS_AND_STRINGS.matcher(a).results()
					// type conversion
					.map((r) -> r.group(1) != null ? r.group(1) : parseNumber(r.group(2)))
					// get
					.collect(Collectors.toUnmodifiableList()));
		} catch (ExecutionException e) {
			throw new RuntimeException("Something went wrong", e);
		}
	}

	/**
	 * Rebuild a string from its letter runs and digit runs (leading zeros stripped), joined by
	 * single spaces.
	 *
	 * @param a the input
	 * @return the normalized string, or {@code null} for blank input
	 */
	protected static final String spaceStringsAndNumbers(final String a) {
		if (StringUtils.isBlank(a)) {
			return null;
		}
		return NUMBERS_AND_STRINGS.matcher(a).results().map((r) -> r.group(1) != null ? r.group(1) : r.group(2)).collect(Collectors.joining(" "));
	}

	/**
	 * Compare every string in {@code as} with every string in {@code bs}.
	 *
	 * <p>
	 * A case-insensitive exact match adds {@code scoreForMatch}. Otherwise the pair adds its
	 * {@link #stringsAndNumbersCompare(Collection, String, String)} result times 80% of
	 * {@code scoreForMatch}.
	 *
	 * @param matches receives the matching fragments
	 * @param scoreForMatch score of an exact match
	 * @param as first collection
	 * @param bs second collection
	 * @return the accumulated score; 0 if either collection is empty
	 */
	protected double compareStrings(final Collection<String> matches, double scoreForMatch, Collection<String> as, Collection<String> bs) {
		if (CollectionUtils.isEmpty(as) || CollectionUtils.isEmpty(bs)) {
			return 0;
		}

		double total = 0;
		for (var balias : bs) {
			for (var aalias : as) {
				if (notNullEquals(matches, aalias, balias)) {
					total += scoreForMatch;
				} else {
					// Partial credit for token overlap, at most 80% of an exact match
					total += stringsAndNumbersCompare(matches, aalias, balias) * (scoreForMatch * 0.8);
				}
			}
		}
		return total;
	}

	/**
	 * Debugging aid: print both similarity measures for every pair of distinct texts to stderr.
	 * Identical pairs are checked with {@code assert} to score 1.
	 *
	 * @param texts the texts to compare
	 */
	protected void test(String[] texts) {
		for (var a : texts) {
			for (var b : texts) {
				if (StringUtils.equals(a, b)) {
					assert (similarityScore(a, b) == 1);
					assert (stringsAndNumbersCompare(a, b) == 1);
					continue;
				}
				System.err.println("\n\n" + a + "\n" + b + "\n----");
				System.err.println("similarityScore:             \t" + similarityScore(a, b));
				System.err.println("stringsAndNumbersCompare:    \t" + stringsAndNumbersCompare(a, b));

				System.err.println("\n\n");
			}
		}
	}
}