Subversion Repositories general

Rev

Rev 1072 | Blame | Compare with Previous | Last modification | View Log | RSS feed

package ak.hostadmiral.core.servlet;

import java.util.List;
import java.util.ArrayList;
import java.util.StringTokenizer;
import java.io.IOException;
import java.net.URLEncoder;
import java.security.Principal;
import java.security.AccessControlException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.log4j.Logger;

import ak.backpath.BackPath;

import ak.hostadmiral.util.ModelException;
import ak.hostadmiral.core.servlet.LoginInfo;
import ak.hostadmiral.core.model.User;
import ak.hostadmiral.core.model.UserManager;

/**
 * Ensures that user is logged in to the system to process its request.
 *
 * <p>Requests whose servlet path is listed in the <code>passUrls</code> init
 * parameter (or is the login servlet itself) are passed down the chain without
 * any check. For every other request the HTTP session must hold a user object
 * under {@link SessionKeys#USER}; if it holds only a {@link LoginInfo} under
 * {@link SessionKeys#LOGIN_INFO}, the user is reloaded through
 * {@link UserManager} and stored back into the session. When neither is
 * available the client is redirected to the configured login URL.
 *
 * <p>Init parameters:
 * <ul>
 * <li><code>loginUrl</code> - required; URL of the login page relative to the
 *     context path, a leading "/" is added if missing;</li>
 * <li><code>passUrls</code> - optional; list of servlet paths separated by
 *     ";". An entry ending with "*" is a prefix mask, any other entry must
 *     match the servlet path exactly. Whitespace around entries is ignored
 *     and empty entries are skipped.</li>
 * </ul>
 */
public class LoginFilter
        implements Filter
{
        private static final Logger logger = Logger.getLogger(LoginFilter.class);

        private FilterConfig filterConfig;
        private String       loginUrl;
        private String       loginServlet;
        private List         passUrls  = new ArrayList(); // exact servlet paths (String)
        private List         passMasks = new ArrayList(); // servlet path prefixes (String)

        /**
         * Reads the filter configuration: the login URL and the list of
         * pass-through URLs.
         *
         * @param filterConfig configuration supplied by the container
         * @throws ServletException if there is no configuration
         *         or no <code>loginUrl</code> parameter
         */
        public void init(FilterConfig filterConfig)
                throws ServletException
        {
                // get config
                this.filterConfig = filterConfig;

                if(filterConfig == null)
                        throw new ServletException("No configuration for the filter");

                // get login url
                loginUrl = filterConfig.getInitParameter("loginUrl");

                if(loginUrl == null)
                        throw new ServletException("No login URL specified");

                // ensure it's absolute path
                if(!loginUrl.startsWith("/"))
                        loginUrl = "/" + loginUrl;

                // get servlet part of the url
                int qPos = loginUrl.indexOf('?');

                if(qPos < 0)
                        loginServlet = loginUrl;
                else
                        loginServlet = loginUrl.substring(0, qPos);

                // get pass through URLs
                String passUrlsStr = filterConfig.getInitParameter("passUrls");
                if(passUrlsStr != null) {
                        String[] urls = passUrlsStr.split(";");

                        for(int i = 0; i < urls.length; i++) {
                                // trim each entry on its own so that whitespace at the
                                // very beginning or end of the parameter (e.g. line
                                // breaks of a multi-line param-value) is removed too
                                String url = urls[i].trim();

                                // an empty entry would be equal to the empty servlet path
                                // the Servlet API reports for servlets mapped to "/*",
                                // silently letting such requests through
                                if(url.length() == 0)
                                        continue;

                                if(url.endsWith("*"))
                                        passMasks.add(url.substring(0, url.length() - 1));
                                else
                                        passUrls.add(url);
                        }
                }

                // avoid loop: the login page itself must be reachable without login
                if(!isPassThrough(loginServlet)) {
                        passUrls.add(loginServlet);
                }
        }

        /**
         * Tells whether the servlet path is excluded from the login check.
         *
         * @param url servlet path of the request
         * @return <code>true</code> if the path equals one of the exact
         *         pass-through URLs or starts with one of the masks
         */
        private boolean isPassThrough(String url)
        {
                for(int i = 0; i < passUrls.size(); i++) {
                        if(url.equals((String)passUrls.get(i))) return true;
                }

                for(int i = 0; i < passMasks.size(); i++) {
                        if(url.startsWith((String)passMasks.get(i))) return true;
                }

                return false;
        }

        /**
         * Passes the request down the chain if it is a pass-through one or if
         * a user is found in the session; otherwise redirects to the login URL.
         *
         * @throws ServletException if the request or response is not an HTTP one,
         *         or if the session holds objects of unexpected types
         * @throws IOException if the redirect or the rest of the chain fails
         */
        public void doFilter(ServletRequest request, ServletResponse response,
                        FilterChain chain)
                throws IOException, ServletException
        {
                if(!(request instanceof HttpServletRequest))
                        throw new ServletException("Do not know how to handle non-HTTP requests");
                if(!(response instanceof HttpServletResponse))
                        throw new ServletException("Do not know how to handle non-HTTP responses");

                HttpServletRequest  httpRequest  = (HttpServletRequest)request;
                HttpServletResponse httpResponse = (HttpServletResponse)response;

                logger.debug("Requested " + httpRequest.getServletPath());

                if(isPassThrough(httpRequest.getServletPath())) {
                        logger.debug("pass through");
                }
                else {
                        try {
                                ensureUser(httpRequest);
                        }
                        catch(AccessControlException ex) {
                                redirectToLogin(httpRequest, httpResponse, ex.getMessage());
                                return;
                        }
                }

                // no problems found; called outside of the try block above so that
                // an AccessControlException thrown further down the chain
                // is not turned into a redirect
                chain.doFilter(request, response);
        }

        /**
         * Checks that the session of the request holds a user, trying to reload
         * the user from the stored login information if needed.
         *
         * @throws AccessControlException if there is no session, no user and
         *         no login information, or the user cannot be reloaded
         * @throws ServletException if the session attributes are of wrong types
         */
        private void ensureUser(HttpServletRequest httpRequest)
                throws ServletException
        {
                HttpSession session = httpRequest.getSession(false);

                if(session == null)
                        throw new AccessControlException("No session");

                Object userObj = session.getAttribute(SessionKeys.USER);
                if(userObj != null) {
                        if(!(userObj instanceof User))
                                throw new ServletException(
                                        "Wrong type of user information: " + userObj.getClass().getName());

                        logger.debug("User found - OK");
                        return;
                }

                // try to relogin
                Object loginInfoObj = session.getAttribute(SessionKeys.LOGIN_INFO);
                if(loginInfoObj == null)
                        throw new AccessControlException("No user");

                if(!(loginInfoObj instanceof LoginInfo))
                        throw new ServletException(
                                "Wrong type of login info information: "
                                        + loginInfoObj.getClass().getName());

                try {
                        userObj = UserManager.getInstance().get(((LoginInfo)loginInfoObj).getId());
                }
                catch(ModelException ex) {
                        // AccessControlException cannot carry a cause,
                        // so at least leave a trace of the real problem in the log
                        logger.warn("Cannot reload user from login info", ex);
                        throw new AccessControlException("No user");
                }

                if(userObj == null)
                        throw new AccessControlException("No user");

                session.setAttribute(SessionKeys.USER, userObj);
                logger.debug("User re-logined: " + userObj);
        }

        /**
         * Redirects the client to the login URL, appending the back path
         * parameters of the current request when they can be obtained.
         *
         * @param reason why the login is required, used for logging only
         * @throws IOException if the redirect cannot be sent
         */
        private void redirectToLogin(HttpServletRequest httpRequest,
                        HttpServletResponse httpResponse, String reason)
                throws IOException
        {
                String redirectUrl;
                try {
                        redirectUrl = httpRequest.getContextPath() + loginUrl
                                + BackPath.findBackPath(httpRequest).getForwardParams();
                }
                catch(Exception ex) {
                        // best effort: fall back to the plain login URL
                        logger.error("Cannot get forward redirect", ex);
                        redirectUrl = httpRequest.getContextPath() + loginUrl;
                }

                logger.info("Redirect because of '" + reason + "' to " + redirectUrl);
                httpResponse.sendRedirect(httpResponse.encodeRedirectURL(redirectUrl));
        }

        /**
         * Nothing to release.
         */
        public void destroy()
        {
        }
}