LocaleURLFilter.java

/*
 * Copyright 2020 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.gringlobal.spring.locale;

import java.io.IOException;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Locale;
import java.util.regex.Matcher;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.web.filter.GenericFilterBean;

/**
 * Handles the locale selection by URL prefix (/es, /pt, ...).
 *
 * <p>For a request whose path (relative to the context path) matches the {@link LocaleURLMatcher}
 * pattern, the first group is taken as the language tag and the second group as the remaining URL.
 * The filter then does one of the following:
 * <ul>
 * <li>If an allow-list is configured and the language is not in it, it temporarily redirects to
 * the un-prefixed URL.</li>
 * <li>If the language equals the configured default locale, it permanently redirects (301) to the
 * un-prefixed URL.</li>
 * <li>Otherwise it records the locale in request attributes, binds it to
 * {@link LocaleContextHolder}, and continues the chain. The request is wrapped so that
 * {@code getRequestURI()} and {@code getServletPath()} drop the prefix, and the response is wrapped
 * so that URLs encoded or redirected downstream get the prefix re-added.</li>
 * </ul>
 * Requests without a language prefix continue with the default locale.
 */
@Slf4j
public class LocaleURLFilter extends GenericFilterBean {

	/** Per-instance matcher: {@link #setExcludePaths(String...)} configures it, so it must not be shared between filters. */
	private final LocaleURLMatcher localeUrlMatcher = new LocaleURLMatcher();
	public static final String REQUEST_LOCALE_ATTR = LocaleURLFilter.class.getName() + ".LOCALE";
	private static final String REQUEST_INTERNAL_URL = LocaleURLFilter.class.getName() + ".INTERNALURL";
	private static final String REQUEST_LOCALE_LANGUAGE = LocaleURLFilter.class.getName() + ".LANGUAGE";

	/** Language tags accepted as URL prefixes (case-insensitive); {@code null} means any. */
	private String[] allowedLocales = null;
	/** Locale used for un-prefixed URLs; a prefix equal to it is redirected away. */
	private Locale defaultLocale;

	/**
	 * Sets the locale served by un-prefixed URLs.
	 *
	 * @param defaultLocale the default locale
	 */
	public void setDefaultLocale(Locale defaultLocale) {
		this.defaultLocale = defaultLocale;
		log.info("Using default locale: {}", this.defaultLocale);
	}

	/**
	 * Restricts which language prefixes are accepted. The array is copied, so later changes
	 * by the caller have no effect.
	 *
	 * @param allowedLocales accepted language tags, compared case-insensitively; {@code null} allows any
	 */
	public void setAllowedLocales(String... allowedLocales) {
		this.allowedLocales = allowedLocales == null ? null : allowedLocales.clone();
		log.info("Using allowed locales: {}", Arrays.toString(this.allowedLocales));
	}

	/**
	 * Configures paths that bypass locale handling entirely.
	 *
	 * @param excludePaths paths handed to {@link LocaleURLMatcher#setExcludedPaths}
	 */
	public void setExcludePaths(String... excludePaths) {
		for (final String e : excludePaths) {
			log.info("Excluding path: {}", e);
		}
		localeUrlMatcher.setExcludedPaths(excludePaths);
	}

	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
		if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
			// URL locale prefixes only apply to HTTP; pass anything else through instead of failing on the cast
			chain.doFilter(request, response);
			return;
		}
		doFilterInternal((HttpServletRequest) request, (HttpServletResponse) response, chain);
	}

	private void doFilterInternal(HttpServletRequest httpRequest, HttpServletResponse httpResponse, FilterChain filterChain) throws IOException, ServletException {
		// Path relative to the web application (context path stripped)
		final String url = httpRequest.getRequestURI().substring(httpRequest.getContextPath().length());
		log.debug("LocaleURLFilter {}", url);

		if (localeUrlMatcher.isExcludedPath(url)) {
			log.debug("Excluded {}", url);
			filterChain.doFilter(httpRequest, httpResponse);
			return;
		}

		log.debug("Incoming URL: {}", url);
		logRequestAttributes(httpRequest);

		// This attribute is only set further down in this method, so its presence means this filter
		// already handled the request (presumably a forward/error re-dispatch): keep the same prefix.
		final String existingUrlLanguage = (String) httpRequest.getAttribute(REQUEST_LOCALE_LANGUAGE);
		if (existingUrlLanguage != null) {
			final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, existingUrlLanguage, defaultLocale
				.toLanguageTag());
			log.debug("Found REQUEST_LOCALE_LANGUAGE {} in request", existingUrlLanguage);
			filterChain.doFilter(httpRequest, localeResponse);
			return;
		}

		// NOTE(review): LocaleContextHolder is thread-bound and this filter never resets it, so excluded
		// paths may observe a locale left by an earlier request on the same thread. TODO confirm whether a
		// downstream component (e.g. the DispatcherServlet) resets it.
		final Matcher matcher = localeUrlMatcher.matcher(url);
		if (matcher.matches()) {
			final String urlLanguage = matcher.group(1);
			final String remainingUrl = matcher.group(2);

			if (!isLocaleAllowed(urlLanguage)) {
				log.warn("Locale not allowed. Temporary redirect to default locale.");
				// sendRedirect resolves a leading "/" against the container root, so re-add the context path
				httpResponse.sendRedirect(httpRequest.getContextPath() + getInternalUrl(remainingUrl, httpRequest.getQueryString()));
				return;
			}

			final Locale urlLocale = Locale.forLanguageTag(urlLanguage);

			// Redirect for default locale: the canonical URL of the default locale has no prefix
			if (urlLocale.equals(this.defaultLocale)) {
				// The Location header is resolved against the host root, so it must include the context path
				final String defaultLocaleUrl = httpRequest.getContextPath() + getInternalUrl(remainingUrl, httpRequest.getQueryString());
				log.info("Default locale requested, permanent-redirect to {}", defaultLocaleUrl);

				httpResponse.reset();
				httpResponse.setStatus(HttpServletResponse.SC_MOVED_PERMANENTLY);
				httpResponse.setHeader("Location", defaultLocaleUrl);
				return;
			}

			httpRequest.setAttribute(REQUEST_LOCALE_ATTR, urlLocale);
			httpRequest.setAttribute(REQUEST_LOCALE_LANGUAGE, urlLanguage);
			httpRequest.setAttribute(REQUEST_INTERNAL_URL, getInternalUrl(remainingUrl, httpRequest.getQueryString()));

			if (log.isTraceEnabled()) {
				log.trace("URL matches! lang={} remaining={}", urlLanguage, remainingUrl);
				log.trace("Country: {} Lang: {} locale={}", urlLocale.getCountry(), urlLocale.getLanguage(), urlLocale);
				logRequestAttributes(httpRequest);
				log.trace("Proxying request to remaining URL {}", remainingUrl);
			}

			LocaleContextHolder.setLocale(urlLocale);

			final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, urlLanguage, defaultLocale.toLanguageTag());
			final LocaleWrappedServletRequest localeRequest = new LocaleWrappedServletRequest(httpRequest, url, remainingUrl);

			filterChain.doFilter(localeRequest, localeResponse);
		} else {
			final String internalUrl = getInternalUrl(url, httpRequest.getQueryString());
			log.debug("No match on url {} setting {}", url, internalUrl);
			httpRequest.setAttribute(REQUEST_INTERNAL_URL, internalUrl);
			// Empty marker: request handled, no language prefix
			httpRequest.setAttribute(REQUEST_LOCALE_LANGUAGE, "");

			LocaleContextHolder.setLocale(defaultLocale);

			final LocaleWrappedServletResponse localeResponse = new LocaleWrappedServletResponse(httpResponse, localeUrlMatcher, null, defaultLocale.toLanguageTag());
			filterChain.doFilter(httpRequest, localeResponse);
		}
	}

	/** Returns {@code true} when no allow-list is configured or it contains {@code urlLanguage} (case-insensitive). */
	private boolean isLocaleAllowed(final String urlLanguage) {
		if (this.allowedLocales == null) {
			return true;
		}
		return Arrays.stream(this.allowedLocales).anyMatch(allowedLocale -> allowedLocale.equalsIgnoreCase(urlLanguage));
	}

	/** Logs every request attribute at TRACE level; does nothing unless TRACE is enabled. */
	private static void logRequestAttributes(final HttpServletRequest httpRequest) {
		if (!log.isTraceEnabled()) {
			return;
		}
		final Enumeration<String> attrNames = httpRequest.getAttributeNames();
		while (attrNames.hasMoreElements()) {
			final String attrName = attrNames.nextElement();
			log.trace("Request attr {} = {}", attrName, httpRequest.getAttribute(attrName));
		}
	}

	/** Appends {@code ?queryString} to {@code url} unless the query string is blank. */
	private String getInternalUrl(final String url, final String queryString) {
		if (StringUtils.isBlank(queryString)) {
			return url;
		}
		return url + "?" + queryString;
	}

	/**
	 * Request wrapper that hides the language prefix: when the wrapped request still reports the
	 * original (prefixed) path, the prefix-less remaining path is returned instead.
	 */
	@Slf4j
	public static class LocaleWrappedServletRequest extends HttpServletRequestWrapper {

		/** Path without the language prefix, relative to the context path. */
		private final String remainingUrl;

		/** Original path including the language prefix, relative to the context path. */
		private final String originalUrl;

		public LocaleWrappedServletRequest(HttpServletRequest request, String originalUrl, String remainingUrl) {
			super(request);
			this.originalUrl = originalUrl;
			this.remainingUrl = remainingUrl;
		}

		@Override
		public String getServletPath() {
			// The servlet path is context-relative, so it is compared with originalUrl directly
			String servletPath = super.getServletPath();
			if (this.originalUrl.equals(servletPath)) {
				log.trace("servletPath={} remaining={}", servletPath, remainingUrl);
				return remainingUrl;
			}
			return servletPath;
		}

		@Override
		public String getRequestURI() {
			// The request URI includes the context path while originalUrl/remainingUrl do not
			String requestURI = super.getRequestURI();
			String contextPath = getContextPath();
			if ((contextPath + this.originalUrl).equals(requestURI)) {
				String strippedURI = contextPath + remainingUrl;
				log.trace("requestURI={} remaining={}", requestURI, strippedURI);
				return strippedURI;
			}
			return requestURI;
		}
	}

	/**
	 * Response wrapper that prefixes non-excluded URLs passed to the encode and redirect methods,
	 * and to the {@code Location} header, with {@code /<language>}.
	 */
	@Slf4j
	public static class LocaleWrappedServletResponse extends HttpServletResponseWrapper {

		/** "/" + language, or "" when no language prefix applies. */
		private final String prefix;
		private final LocaleURLMatcher localeUrlMatcher;
		/** "/" + default language + "/", trimmed from excluded URLs in {@link #encodeURL(String)}. */
		private final String defaultLanguagePrefix;

		public LocaleWrappedServletResponse(HttpServletResponse response, LocaleURLMatcher localeUrlMatcher, String urlLanguage, String defaultLanguage) {
			super(response);
			this.localeUrlMatcher = localeUrlMatcher;
			this.prefix = updatePrefix(urlLanguage);
			log.debug("Response prefix={} lang={}", prefix, urlLanguage);
			this.defaultLanguagePrefix = "/" + defaultLanguage + "/";
		}

		private boolean isExcluded(String url) {
			// Exclude querystring-only urls, URLs starting with //, http:// or https:// and anything excluded by
			// the matcher
			boolean excluded = url.startsWith("?") || url.startsWith("//") || url.startsWith("http://") || url.startsWith("https://") || localeUrlMatcher.isExcluded(url);
			log.trace("isExcluded? {} --> {}", url, excluded);
			return excluded;
		}

		@Override
		public String encodeURL(String url) {
			if (isExcluded(url)) {
				if (url.startsWith(defaultLanguagePrefix)) {
					log.debug("URL starts with defaultLanguagePrefix={} trimming down", defaultLanguagePrefix);
					// Keep the slash after the language tag: "/en/x" -> "/x"
					return super.encodeURL(url.substring(defaultLanguagePrefix.length() - 1));
				}
				return super.encodeURL(url);
			} else {
				String encodedURL = prefix + super.encodeURL(url);
				log.debug("encodeURL {} to {}", url, encodedURL);
				return encodedURL;
			}
		}

		@Override
		@Deprecated
		public String encodeUrl(String url) {
			if (isExcluded(url)) {
				return super.encodeUrl(url);
			} else {
				String encodedURL = prefix + super.encodeUrl(url);
				log.debug("encodeUrl {} to {}", url, encodedURL);
				return encodedURL;
			}
		}

		@Override
		public String encodeRedirectURL(String url) {
			if (isExcluded(url)) {
				return super.encodeRedirectURL(url);
			} else {
				String encodedURL = prefix + super.encodeRedirectURL(url);
				log.debug("encodeRedirectURL {} to {}", url, encodedURL);
				return encodedURL;
			}
		}

		@Override
		@Deprecated
		public String encodeRedirectUrl(String url) {
			if (isExcluded(url)) {
				return super.encodeRedirectUrl(url);
			} else {
				String encodedURL = prefix + super.encodeRedirectUrl(url);
				log.debug("encodeRedirectUrl {} to {}", url, encodedURL);
				return encodedURL;
			}
		}

		@Override
		public void sendRedirect(String location) throws IOException {
			if (isExcluded(location)) {
				super.sendRedirect(location);
			} else {
				String prefixedUrl = prefix + location;
				log.debug("sendRedirect {} to {}", location, prefixedUrl);
				super.sendRedirect(prefixedUrl);
			}
		}

		@Override
		public void setHeader(String name, String value) {
			log.debug("setHeader {}: {}", name, value);
			if ("Location".equalsIgnoreCase(name)) {
				if (isExcluded(value)) {
					super.setHeader(name, value);
				} else {
					String prefixedUrl = prefix + value;
					log.debug("Rewrote redirect header {}: {} -> {}", name, value, prefixedUrl);
					super.setHeader(name, prefixedUrl);
				}
			} else {
				super.setHeader(name, value);
			}
		}

		/** Builds the URL prefix for a language: "" for blank, otherwise "/" + language. */
		private String updatePrefix(String language) {
			if (StringUtils.isBlank(language)) {
				return "";
			} else {
				return "/" + language;
			}
		}
	}
}