TableMatrix.java

/*
 * Copyright 2021 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.gringlobal.compatibility.service.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.gringlobal.soap.Datatable.Row;

/**
 * Java implementation of the Table join matrix.
 * <p>
 * Ported from the USDA-NPGS .NET implementation; the original C# is kept in the
 * {@code @formatter:off} comments for reference. The matrix is a square array over the
 * known tables where cell {@code [from][to]} holds the <em>first</em> {@link KJoin} to take
 * when walking from table {@code from} towards table {@code to}. Full join paths are
 * reconstructed by following those first steps until the target table is reached.
 */
@Slf4j
public class TableMatrix {

	private final List<String> sysTableList;
	private final Map<String, String> alias2Full;
	private final KJoin[][] joinPathArray;

	// Once per instance cache of computed join paths, keyed by "fromTable:toTable".
	// NOTE(review): plain HashMap, so generateFromClause must not be called concurrently on a shared instance.
	private final Map<String, List<KJoin>> fixedCache = new HashMap<>();

	/**
	 * Creates an empty join matrix (no known paths) over the given tables.
	 *
	 * @param sysTableList table names, optionally written as "table alias"; a table's position in this list is its
	 *                     row/column index in the matrix
	 * @param alias2Full   maps a table alias to its full "table alias" entry in {@code sysTableList}
	 */
	public TableMatrix(List<String> sysTableList, Map<String, String> alias2Full) {
		this.sysTableList = new ArrayList<>(sysTableList);
		this.alias2Full = new HashMap<>(alias2Full);

		// The matrix: every cell starts out null, meaning "no known path"
		int size = this.sysTableList.size();
		joinPathArray = new KJoin[size][size];
	}

	/**
	 * Returns the matrix index of {@code tableName}, or -1 if the table is not known.
	 * Like the original lookup loops, the last matching entry wins if a name occurs more than once.
	 */
	private int indexOfTable(String tableName) {
		for (int i = sysTableList.size() - 1; i >= 0; i--) {
			if (sysTableList.get(i).equals(tableName)) {
				return i;
			}
		}
		return -1;
	}

	/**
	 * Adds a direct link join to the matrix and computes all the indirect links.
	 * <p>
	 * Nothing happens if either table is unknown or if the matrix already knows a path from
	 * {@code fromTable} to {@code toTable}.
	 *
	 * @param fromTable     table the join starts from (matrix row)
	 * @param toTable       table the join leads towards (matrix column)
	 * @param nextTable     table actually joined by this step
	 * @param fromField     column of {@code fromTable} used in the join condition
	 * @param nextField     column of {@code nextTable} used in the join condition
	 * @param extraJoinCode optional SQL appended after the join condition
	 */
	// private void addLink(string fromTable, string toTable, string nextTable,
	// string fromField, string nextField, string extraJoinCode) {
	private void addLink(String fromTable, String toTable, String nextTable, String fromField, String nextField, String extraJoinCode) {
		/* @formatter:off */
//		int fromId = -1;
//		int toId = -1;
//		for (int i = 0; i < sysTableList.Count; i++) {
//		    if (sysTableList[i] == fromTable) { fromId = i; }
//		    if (sysTableList[i] == toTable) { toId = i; }
//		}
//		if (fromId < 0 || toId < 0) return;
//		
//		// Add the join to the matrix unless it already knows a path between the two tables
//		if (joinPathArray[fromId, toId] != null) return;
//		joinPathArray[fromId, toId] = new KJoin(fromTable, fromField, nextTable, nextField, extraJoinCode);
//		
//		// If something knows the direction to from table but not the to table add that
//		for (int i = 0; i < sysTableList.Count; i++) {
//		    if (joinPathArray[i, fromId] != null && joinPathArray[i, toId] == null) {
//		        joinPathArray[i, toId] = joinPathArray[i, fromId];
//		    }
//		}
//		
//		// If we know how to get somewhere from the to but not the from fix it
//		for (int i = 0; i < sysTableList.Count; i++) {
//		    if (joinPathArray[toId, i] != null && joinPathArray[fromId, i] == null) {
//		        joinPathArray[fromId, i] = joinPathArray[fromId, toId];
//		    }
//		}
//		
//		// combine info on all the ways we know to get to from whti all the plces we know to go from to
//		for (int i = 0; i < sysTableList.Count; i++) {
//		    if (joinPathArray[i, fromId] != null) {
//		        for (int j = 0; j < sysTableList.Count; j++) {
//		            if (joinPathArray[toId, j] != null) {
//		                if (joinPathArray[i, j] == null) {
//		                    joinPathArray[i, j] = joinPathArray[i, fromId];
//		                }
//		            }
//		        }
//		    }
//		}
		/* @formatter:on */

		int fromId = indexOfTable(fromTable);
		int toId = indexOfTable(toTable);
		if (fromId < 0 || toId < 0) {
			return;
		}

		// Add the join to the matrix unless it already knows a path between the two tables
		if (joinPathArray[fromId][toId] != null) {
			return;
		}
		joinPathArray[fromId][toId] = new KJoin(fromTable, fromField, nextTable, nextField, extraJoinCode);

		int size = sysTableList.size();

		// Tables that can reach fromTable can now reach toTable: their first step is the same as towards fromTable
		for (int i = 0; i < size; i++) {
			if (joinPathArray[i][fromId] != null && joinPathArray[i][toId] == null) {
				joinPathArray[i][toId] = joinPathArray[i][fromId];
			}
		}

		// Everything reachable from toTable is now reachable from fromTable by first taking the new link
		for (int i = 0; i < size; i++) {
			if (joinPathArray[toId][i] != null && joinPathArray[fromId][i] == null) {
				joinPathArray[fromId][i] = joinPathArray[fromId][toId];
			}
		}

		// Combine all the ways we know to get to fromTable with all the places we know how to reach from toTable
		for (int i = 0; i < size; i++) {
			if (joinPathArray[i][fromId] != null) {
				for (int j = 0; j < size; j++) {
					if (joinPathArray[toId][j] != null && joinPathArray[i][j] == null) {
						joinPathArray[i][j] = joinPathArray[i][fromId];
					}
				}
			}
		}
	}

	/**
	 * Unconditionally overwrites the first step from {@code fromTable} towards {@code toTable} with a join to
	 * {@code nextTable}. Indirect paths are not recomputed. Unknown tables are ignored.
	 */
	// private void replaceStep(List<string> sysTableArray, KJoin[,] joinPathArray,
	// string fromTable, string toTable, string nextTable, string fromField, string
	// nextField) {
	private void replaceStep(String fromTable, String toTable, String nextTable, String fromField, String nextField) {
		/* @formatter:off */
//    int fromId = -1;
//    int toId = -1;
//    for (int i = 0; i < sysTableArray.Count; i++) {
//        if (sysTableArray[i] == fromTable) { fromId = i; }
//        if (sysTableArray[i] == toTable) { toId = i; }
//    }
//    if (fromId < 0 || toId < 0) return;
//    joinPathArray[fromId, toId] = new KJoin(fromTable, fromField, nextTable, nextField, null);
//}
		/* @formatter:on */

		int fromId = indexOfTable(fromTable);
		int toId = indexOfTable(toTable);
		if (fromId < 0 || toId < 0) {
			return;
		}
		joinPathArray[fromId][toId] = new KJoin(fromTable, fromField, nextTable, nextField, null);
	}

	/**
	 * Builds an SQL {@code FROM} clause that starts at {@code fromTable} and adds every join needed to reach each
	 * table in {@code joinTables}. Tables already joined are not joined again.
	 *
	 * @param fromTable  driving table of the clause
	 * @param joinType   text placed (space separated) before every {@code JOIN ...} fragment
	 * @param joinTables tables (or table aliases) that must be reachable in the resulting clause
	 * @return the clause, starting with {@code " FROM "}
	 */
	// public string generateFromClause(string fromTable, string joinType,
	// List<string> joinTables) {
	public String generateFromClause(String fromTable, String joinType, Collection<String> joinTables) {
		/* @formatter:off */
//		string sql = " FROM " + fromTable;
//		
//		// Add joins for each of the tables 
//		List<string> joinedTables = new List<string>();
//		joinedTables.Add(fromTable);
//		foreach (string nTable in joinTables) {
//		    string needed = nTable;
//		    if (alias2Full.ContainsKey(needed)) { needed = alias2Full[needed]; }
//		    if (!joinedTables.Contains(needed)) {
//		        List<KJoin> joinPath = joinPathToTable(fromTable, needed);
//		        // add each join in the path to the table unless it's already joined
//		        foreach (KJoin nextJoin in joinPath) {
//		            if (!joinedTables.Contains(nextJoin.ToTableName)) {
//		                sql += joinType + nextJoin.ToString();
//		                joinedTables.Add(nextJoin.ToTableName);
//		            }
//		        }
//		    }
//		}
		/* @formatter:on */

		StringBuilder sql = new StringBuilder(200);
		sql.append(" FROM ").append(fromTable);

		// Tables already present in the clause; the driving table is always the first one
		List<String> joinedTables = new ArrayList<>();
		joinedTables.add(fromTable);
		for (String nTable : joinTables) {
			// Translate an alias to the full "table alias" name used as matrix key
			String alias = alias2Full.get(nTable);
			String needed = alias != null ? alias : nTable;
			if (joinedTables.contains(needed)) {
				continue;
			}
			// add each join in the path to the table unless it's already joined
			for (KJoin nextJoin : joinPathToTable(fromTable, needed)) {
				log.debug("Adding {}", nextJoin);
				if (!joinedTables.contains(nextJoin.toTableName)) {
					sql.append(" ").append(joinType).append(" ").append(nextJoin);
					joinedTables.add(nextJoin.toTableName);
				}
			}
		}

		return sql.toString();
	}

	/**
	 * Returns the sequence of joins leading from {@code fromTable} to {@code toTable}, using the per-instance cache.
	 * If no complete path exists (unknown table or cyclic matrix entries), the partial path found so far is returned
	 * and a warning is logged.
	 */
	// private List<KJoin> joinPathToTable(string fromTable, string toTable) {
	private List<KJoin> joinPathToTable(String fromTable, String toTable) {
		/* @formatter:off */
//		List<KJoin> joinResult = new List<KJoin>();
//		string sofar = fromTable;
//		do {
//		    KJoin nextStep = nextJoinInPath(sofar, toTable);
//		    if (nextStep == null) break;
//		    joinResult.Add(nextStep);
//		    sofar = nextStep.ToTableName;
//		} while (sofar != toTable);
//		return joinResult;
		/* @formatter:on */

		String cacheKey = fromTable + ":" + toTable;

		List<KJoin> cached = fixedCache.get(cacheKey);
		if (cached != null) {
			return cached;
		}

		List<KJoin> joinResult = new ArrayList<>();
		String sofar = fromTable;
		do {
			KJoin nextStep = nextJoinInPath(sofar, toTable);
			if (nextStep == null) {
				break;
			}
			// Each matrix row only holds joins starting at that row's table and the walk is deterministic,
			// so a repeated step means we are going round in a cycle. This also bounds the path length
			// by the number of tables, so no artificial length limit is needed.
			if (joinResult.contains(nextStep)) {
				log.debug("Already contains the same step: {}", nextStep);
				break;
			}
			joinResult.add(nextStep);
			sofar = nextStep.toTableName;
		} while (!sofar.equals(toTable));

		if (!sofar.equals(toTable)) {
			log.warn("Incomplete join path from {} to {}: {}", fromTable, toTable, joinResult);
		}

		// Cached lists are handed out repeatedly, so make them read-only
		List<KJoin> path = List.copyOf(joinResult);
		fixedCache.put(cacheKey, path);
		return path;
	}

	/**
	 * Returns the first join to take from {@code fromTable} towards {@code toTable}, or {@code null} when either
	 * table is unknown or no path is known.
	 */
	// private KJoin nextJoinInPath(string fromTable, string toTable) {
	private KJoin nextJoinInPath(String fromTable, String toTable) {
		/* @formatter:off */
//		int fromId = -1;
//		int toId = -1;
//		for (int i = 0; i < sysTableList.Count; i++) {
//		    if (sysTableList[i] == fromTable) { fromId = i; }
//		    if (sysTableList[i] == toTable) { toId = i; }
//		}
//		if (fromId < 0 || toId < 0) return null;
//		return joinPathArray[fromId, toId];
		/* @formatter:on */

		int fromId = indexOfTable(fromTable);
		int toId = indexOfTable(toTable);
		if (fromId < 0 || toId < 0) {
			return null;
		}
		return joinPathArray[fromId][toId];
	}

	/**
	 * Builds the join matrix from {@code sys_matrix_input}-style rows.
	 * <p>
	 * Columns are read by position: command, child table, parent table, destination table, child field,
	 * parent field and extra join code. Commands {@code map} add links in both directions; {@code replace}
	 * overrides individual steps.
	 *
	 * @param rows matrix input rows
	 * @return the configured matrix
	 */
	public static TableMatrix from(List<Row> rows) {
		List<SysMatrixElement> all = rows.stream()
				.map(row -> new SysMatrixElement((String) row.getValue(0), (String) row.getValue(1), (String) row.getValue(2), (String) row.getValue(3),
						(String) row.getValue(4), (String) row.getValue(5), (String) row.getValue(6)))
				.collect(Collectors.toList());

		/* @formatter:off */
//		// Create a list of distinct tables and table aliases seen in the input
//		sysTableList = new List<string>();
//		alias2Full = new Dictionary<string, string>();
//		foreach (DataRow dr in spanningTreeJoinList.Tables["sys_matrix_input"].Rows) {
//		    string parentTable = dr["parent_table"].ToString().Trim().ToLower();
//		    string childTable = dr["child_table"].ToString().Trim().ToLower();
//		    if (!sysTableList.Contains(parentTable)) {
//		        sysTableList.Add(parentTable);
//		        if (parentTable.Contains(" ")) {
//		            string alias = parentTable.Substring(parentTable.IndexOf(" ") + 1);
//		            //throw Library.CreateBusinessException(getDisplayMember("hasPermission{nomask}", "DEBUG: adding alias '" + alias + "' = '" + parentTable + "'"));
//		            alias2Full.Add(alias, parentTable);
//		        }
//		    }
//		    if (!sysTableList.Contains(childTable)) {
//		        sysTableList.Add(childTable);
//		        if (childTable.Contains(" ")) {
//		            string alias = childTable.Substring(childTable.IndexOf(" ") + 1);
//		            alias2Full.Add(alias, childTable);
//		        }
//		    }
//		}
		/* @formatter:on */

		// Create a list of distinct tables and table aliases seen in the input
		List<String> sysTableList = new ArrayList<>();
		Map<String, String> alias2Full = new HashMap<>();
		for (SysMatrixElement dr : all) {
			registerTable(dr.parentTable, sysTableList, alias2Full);
			registerTable(dr.childTable, sysTableList, alias2Full);
		}

		TableMatrix matrix = new TableMatrix(sysTableList, alias2Full);

		/* @formatter:off */
//		// process each matrix input to configure the join matrix
//		foreach (DataRow dr in spanningTreeJoinList.Tables["sys_matrix_input"].Rows) {
//		    string childTable = dr["child_table"].ToString().Trim().ToLower();
//		    string parentTable = dr["parent_table"].ToString().Trim().ToLower();
//		    string childField = dr["child_field"].ToString().Trim().ToLower();
//		    if (string.IsNullOrEmpty(childField)) { childField = parentTable + "_id"; }
//		    string parentField = dr["parent_field"].ToString().Trim().ToLower();
//		    if (string.IsNullOrEmpty(parentField)) { parentField = parentTable + "_id"; }
//		    string extra = dr["extra_join_code"].ToString().Trim().ToLower();
//		    string command = dr["command"].ToString().Trim().ToLower();
//		    if (command == "map") {
//		        addLink(parentTable, childTable, childTable, parentField, childField, extra);
//		        addLink(childTable, parentTable, parentTable, childField, parentField, extra);
//		    }
//		    if (command == "replace") {
//		        string destTable = dr["dest_table"].ToString().Trim().ToLower();
//		        if (string.IsNullOrEmpty(destTable)) { destTable = parentTable; }
//		        replaceStep(sysTableList, joinPathArray, childTable, destTable, parentTable, childField, parentField);
//		        if (destTable == parentTable) {
//		            replaceStep(sysTableList, joinPathArray, parentTable, childTable, childTable, parentField, childField);
//		        }
//		    }
//		}
		/* @formatter:on */

		// process each matrix input to configure the join matrix
		for (SysMatrixElement dr : all) {
			if (dr.command.equals("map")) {
				matrix.addLink(dr.parentTable, dr.childTable, dr.childTable, dr.parentField, dr.childField, dr.extraJoinCode);
				matrix.addLink(dr.childTable, dr.parentTable, dr.parentTable, dr.childField, dr.parentField, dr.extraJoinCode);
			} else if (dr.command.equals("replace")) {
				matrix.replaceStep(dr.childTable, dr.destTable, dr.parentTable, dr.childField, dr.parentField);
				if (dr.destTable.equals(dr.parentTable)) {
					matrix.replaceStep(dr.parentTable, dr.childTable, dr.childTable, dr.parentField, dr.childField);
				}
			}
		}

		return matrix;
	}

	/**
	 * Adds {@code table} to {@code sysTableList} if it is not there yet; a "table alias" entry is also registered
	 * in {@code alias2Full} under its alias.
	 */
	private static void registerTable(String table, List<String> sysTableList, Map<String, String> alias2Full) {
		if (sysTableList.contains(table)) {
			return;
		}
		sysTableList.add(table);
		int space = table.indexOf(' ');
		if (space >= 0) {
			alias2Full.put(table.substring(space + 1), table);
		}
	}

	/**
	 * One parsed matrix input row. Names are trimmed and lower-cased (locale independent); blank optional
	 * columns get the documented defaults.
	 */
	private static class SysMatrixElement {
		final String command;
		final String childTable;
		final String parentTable;
		final String destTable; /* defaults to parent */
		final String childField; /* defaults to parent_table_id */
		final String parentField; /* defaults to parent_table_id */
		final String extraJoinCode; /* rarely used */

		public SysMatrixElement(String command, String childTable, String parentTable, String destTable, String childField, String parentField, String extras) {
			this.command = command.trim().toLowerCase(Locale.ROOT);
			this.childTable = childTable.trim().toLowerCase(Locale.ROOT);
			this.parentTable = parentTable.trim().toLowerCase(Locale.ROOT);
			this.destTable = StringUtils.defaultIfBlank(StringUtils.trim(destTable), this.parentTable).toLowerCase(Locale.ROOT);
			this.childField = StringUtils.defaultIfBlank(StringUtils.trim(childField), this.parentTable + "_id").toLowerCase(Locale.ROOT);
			this.parentField = StringUtils.defaultIfBlank(StringUtils.trim(parentField), this.parentTable + "_id").toLowerCase(Locale.ROOT);
			// Unlike the C# original, the extra SQL is not lower-cased
			this.extraJoinCode = StringUtils.trimToNull(extras);
		}

		@Override
		public String toString() {
			return command + " " + childTable + " to " + parentTable + " dest (defaults to parent)=" + destTable + " cfield=" + childField + " pfield=" + parentField + " extra="
					+ extraJoinCode;
		}
	}

	/**
	 * Java implementation of the KJoin class: a single {@code JOIN ... ON ...} step.
	 * 
	 * @author .Net version by USDA-NPGS
	 */
	private static class KJoin {
		private final String fromTableName;
		private final String fromFieldName;
		private final String toTableName;
		private final String toFieldName;
		private final String extraJoinCode;

		public KJoin(String fromTableName, String fromFieldName, String toTableName, String toFieldName, String extraJoinCode) {
			this.fromTableName = fromTableName;
			this.fromFieldName = fromFieldName;
			this.toTableName = toTableName;
			this.toFieldName = toFieldName;
			this.extraJoinCode = extraJoinCode;
		}

		/**
		 * Builds "qualifier.field", where the qualifier is the alias when {@code tableName} is "table alias" and the
		 * table name otherwise. As in the C# original ({@code Substring(pos, Length - pos)}), the alias keeps its
		 * leading space, which is harmless in SQL.
		 */
		private static String columnRef(String tableName, String fieldName) {
			String qualifier = tableName.trim();
			int pos = qualifier.indexOf(' ');
			if (pos >= 0) {
				qualifier = qualifier.substring(pos);
			}
			return qualifier + "." + fieldName;
		}

		/** Renders the step as {@code JOIN to ON to.field = from.field [extra]}. */
		@Override
		public String toString() {
			StringBuilder returnJoin = new StringBuilder(100);
			returnJoin.append("JOIN ").append(toTableName);
			returnJoin.append(" ON ").append(columnRef(toTableName, toFieldName)).append(" = ");
			returnJoin.append(columnRef(fromTableName, fromFieldName));
			if (!StringUtils.isBlank(extraJoinCode)) {
				returnJoin.append(" ").append(extraJoinCode);
			}
			return returnJoin.toString();
		}
	}
}