#!/usr/bin/python

# reconf-inetd - reconfigure and restart inetd
# Copyright (C) 2010, 2011, 2012 Serafeim Zanikolas <sez@debian.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import unittest
import logging
from StringIO import StringIO

from reconf_inetd import XFragmentContainer, XFragmentParser, InetdService, \
        InetdServiceContainer, InetdConfParser, InvalidEntryException, \
        MissingFieldException

# Root logger shared by every test case below: level DEBUG, with a handler
# that appends records to tests.log in the current working directory.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.FileHandler('tests.log'))

class TestFragmentData(unittest.TestCase):
    """Shared fixtures and helpers for the reconf-inetd test cases.

    Holds sample xinetd-style fragments (as text), the inetd.conf lines some
    of them are expected to convert to, and factory helpers for the parsers
    under test.  Defines no test methods itself.
    """

    # Fragment with every field set; converts to result1.
    fragment1 = \
    """service ftpd-ssl
    {
        disable		= yes
        type		= INTERNAL
        id		= ftpd-ssl-stream
        socket_type	= stream
        protocol	= tcp
        user		= root
        wait		= no
        server		= /usr/sbin/ftpd-ssl
        server_args = -d
    }"""

    # Enabled fragment (disable = no) whose server binary does not exist.
    finger_fragment = \
    """service finger
    {
        disable = no
        socket_type = stream
        protocol = tcp
        flags = IPv6
        wait = no
        user = root
        server = /path/to/non-existent-server
        log_type = SYSLOG daemon info
        log_on_success = HOST
        log_on_failure = HOST
    }"""

    # No `protocol` field, and a non-standard service name for which no
    # protocol can be guessed: parsing raises MissingFieldException.
    fragment_unlisted_srv_without_protocol = \
    """service nonstd_service
    {
        disable		= no
        flags		= IPv6
        socket_type	= stream
        wait		= no
        user		= root
        server		= /usr/sbin/in.ftpd
        server_args	= -l
        log_type	= SYSLOG daemon info
        log_on_failure	= HOST
    }"""

    # No `protocol` field, but for the well-known `ftp` service, whose
    # protocol is guessed; converts to listed_srv_without_protocol.
    fragment_listed_srv_without_protocol = \
    """service ftp
    {
        disable		= no
        flags		= IPv6
        socket_type	= stream
        wait		= no
        user		= root
        server		= /usr/sbin/in.ftpd
        server_args	= -l
        log_type	= SYSLOG daemon info
        log_on_failure	= HOST
    }"""

    # Service run through tcpd with NAMEINARGS; converts to tcpd_result.
    tcpd_fragment = \
    """service nntp
    {
        socket_type	= stream
        protocol	= tcp
        wait		= no
        user		= news
        flags = NAMEINARGS
        server		= /usr/sbin/tcpd
        server_args	= /usr/sbin/leafnode
    }"""

    # Expected inetd.conf lines.
    result1 = 'ftpd-ssl stream tcp nowait root /usr/sbin/ftpd-ssl -d\n'
    result2 = 'proftpd		stream	tcp	nowait	root	/usr/sbin/proftpd\n'
    listed_srv_without_protocol = 'ftp stream tcp nowait root /usr/sbin/in.ftpd -l\n'
    tcpd_result = "nntp stream tcp nowait news /usr/sbin/tcpd /usr/sbin/leafnode\n"

    # Conversion rules visible in the expected lines above: `wait = no`
    # becomes `nowait` (and `wait = yes` becomes `wait`); fields such as
    # `id` and `disable` are dropped (compare fragment1 with result1).

    def _get_first_service(self, container, service):
        """Return the service named `service` held by `container`.

        Asserts that `container.all_services` holds exactly one service and
        that a service with the requested name is among them.
        """
        # assertEqual, not assertTrue: assertTrue(x, 1) treats 1 as the
        # failure message and would accept any non-empty container.
        self.assertEqual(len(container.all_services), 1)
        matches = [s for s in container.all_services if s.get_name() == service]
        self.assertTrue(matches)
        return matches[0]

    def _get_fragment_parser(self):
        """Return an XFragmentParser backed by a fresh XFragmentContainer.

        `current_filename` is set to a placeholder; presumably the parser
        uses it in diagnostics (not visible here).
        """
        fc = XFragmentContainer(logger)
        x = XFragmentParser(fc, logger)
        x.current_filename = 'some-file'
        return x

    def _get_conf_parser(self):
        """Return an InetdConfParser over a fresh InetdServiceContainer and an
        empty in-memory file."""
        inetd_serv_container = InetdServiceContainer(logger)
        return InetdConfParser(logger, inetd_serv_container, StringIO())

class TestXFragmentParser(TestFragmentData):
    """Tests for XFragmentParser: parsing of xinetd fragments, rejection of
    malformed input, and conversion of fragments to inetd.conf lines."""

    def _assert_parse_raises(self, exception, text):
        """Assert that parsing fragment `text` raises `exception`."""
        x = self._get_fragment_parser()
        self.assertRaises(exception, x.parse, StringIO(text))

    def _to_inetd(self, text, service):
        """Parse fragment `text` and return the inetd.conf line produced for
        its single service named `service`."""
        x = self._get_fragment_parser()
        x.parse(StringIO(text))
        return self._get_first_service(x.container, service).to_inetd()

    # --- successful parsing of valid input ---
    # TODO: add a test case for an rpc service (see the xinetd README).

    def test_success1(self):
        x = self._get_fragment_parser()
        x.parse(StringIO(TestXFragmentParser.fragment1))
        # Look the service up in the parser's container, by its actual name.
        fragment = self._get_first_service(x.container, 'ftpd-ssl')
        # NOTE(review): relies on parsed fragments supporting item access by
        # field name -- confirm against reconf_inetd.
        self.assertEqual(fragment['disable'], 'yes')
        self.assertEqual(fragment['type'], 'INTERNAL')
        self.assertEqual(fragment['id'], 'ftpd-ssl-stream')
        self.assertEqual(fragment['socket_type'], 'stream')
        self.assertEqual(fragment['protocol'], 'tcp')
        self.assertEqual(fragment['user'], 'root')
        self.assertEqual(fragment['wait'], 'no')
        self.assertEqual(fragment['server_args'], '-d')

    # --- exceptions raised for invalid input ---

    def test_missing_opening_brace(self):
        self._assert_parse_raises(InvalidEntryException,
                TestXFragmentParser.fragment1.replace('{', ''))

    def test_missing_closing_brace(self):
        self._assert_parse_raises(InvalidEntryException,
                TestXFragmentParser.fragment1.replace('}', ''))

    def test_typo1(self):
        # Trailing garbage glued to the closing brace.
        self._assert_parse_raises(InvalidEntryException,
                TestXFragmentParser.fragment1.replace('}', '}a'))

    def test_typo2(self):
        # Trailing garbage after the closing brace.
        self._assert_parse_raises(InvalidEntryException,
                TestXFragmentParser.fragment1.replace('}', '} a'))

    def test_missing_mandatory_field(self):
        self._assert_parse_raises(MissingFieldException,
                TestXFragmentParser.fragment_unlisted_srv_without_protocol)

    # --- conversion of xinetd fragments to /etc/inetd.conf entries ---
    # Each test has its own name; the previous duplicate method names meant
    # only the last definition was ever run.

    def test_guessed_protocol_for_listed_service(self):
        self.assertEqual(
                self._to_inetd(TestXFragmentParser.fragment_listed_srv_without_protocol, 'ftp'),
                TestXFragmentParser.listed_srv_without_protocol)

    def test_to_inetd_ftpd_ssl(self):
        self.assertEqual(
                self._to_inetd(TestXFragmentParser.fragment1, 'ftpd-ssl'),
                TestXFragmentParser.result1)

    def test_to_inetd_tcpd(self):
        self.assertEqual(
                self._to_inetd(TestXFragmentParser.tcpd_fragment, 'nntp'),
                TestXFragmentParser.tcpd_result)

class TestXFragmentContainer(TestFragmentData):
    """Tests for XFragmentContainer.get_valid_services()."""

    def _get_valid_services(self, text):
        """Parse fragment `text` and return the container's valid services."""
        x = self._get_fragment_parser()
        x.parse(StringIO(text))
        return x.container.get_valid_services()

    def test_disable_is_no_but_non_existing_server_path(self):
        self.assertEqual(self._get_valid_services(self.finger_fragment), [])

    def test_disable_is_yes_and_non_existing_server_path(self):
        f = self.finger_fragment.replace('disable = no', 'disable = yes')
        # Parse the modified fragment; previously the unmodified one was
        # parsed, so the `disable = yes` case was never exercised.
        self.assertEqual(self._get_valid_services(f), [])

    # An earlier test expected a fragment with an existing server path but
    # `disable = yes` to be rejected. It was dropped because the 'disable'
    # field in fragment files is ignored.

    def test_existing_server_path(self):
        # /bin/ls stands in for a server binary that is guaranteed to exist.
        f = self.finger_fragment.replace(
                'server = /path/to/non-existent-server', 'server = /bin/ls')
        enabled_services = self._get_valid_services(f)
        self.assertEqual(len(enabled_services), 1)
        self.assertEqual(enabled_services[0].get_name(), 'finger')

class TestInetdConfParser(TestFragmentData):
    """Tests for InetdConfParser and InetdServiceContainer.has_matching_entry().

    Inherits _get_first_service() from TestFragmentData; the former local
    copy was identical and has been removed.
    """

    # Tab-separated inetd.conf line and the attributes it should parse into.
    entry1 = 'proftpd		stream	tcp	nowait	root	/usr/sbin/proftpd -d'
    result1 = {
            'service': 'proftpd',
            'socket_type': 'stream',
            'protocol': 'tcp',
            'user': 'root',
            'wait': 'nowait',
            'server': '/usr/sbin/proftpd',
            'server_args': '-d',
            'status': InetdService.ENABLED,
            }

    # tcpd-wrapped line; note the NAMEINARGS flag in the expected attributes.
    entry2 = "nntp stream tcp nowait news /usr/sbin/tcpd /usr/sbin/leafnode\n"
    result2 = {
            'service': 'nntp',
            'socket_type': 'stream',
            'protocol': 'tcp',
            'user': 'news',
            'wait': 'nowait',
            'server': '/usr/sbin/tcpd',
            'server_args': '/usr/sbin/leafnode',
            'flags': 'NAMEINARGS',
            'status': InetdService.ENABLED,
            }

    def _assert_attrs(self, service, expected_attrs):
        """Assert that every key in `expected_attrs` has the same value in
        `service.attrs` (extra keys in `service.attrs` are ignored)."""
        for key, expected_value in expected_attrs.iteritems():
            self.assertEqual(service.attrs[key], expected_value)

    def _assert_parsed_proftpd(self, line, expected_status):
        """Parse `line` as line 1 of inetd.conf and check the resulting
        proftpd service against result1 with `status` = `expected_status`."""
        p = self._get_conf_parser()
        p.parse_line(line, 1)
        result = self._get_first_service(p.container, 'proftpd')
        self._assert_attrs(result,
                dict(TestInetdConfParser.result1, status=expected_status))

    def _parse_conf_entry(self, line):
        """Parse `line` as line 1 of inetd.conf; return the container, which
        must hold exactly one service."""
        inetd_conf_parser = self._get_conf_parser()
        inetd_conf_parser.parse_line(line, 1)
        self.assertEqual(len(inetd_conf_parser.container.all_services), 1)
        return inetd_conf_parser.container

    def _parse_fragment(self, name, attrs):
        """Build an xinetd fragment for service `name` with one `key = value`
        line per item of `attrs`, parse it, and return the single parsed
        service."""
        fragment_parser = self._get_fragment_parser()
        body = "".join(["%s = %s\n" % (k, v) for k, v in attrs.iteritems()])
        raw_fragment = "service %s\n{\n %s }" % (name, body)
        fragment_parser.parse(StringIO(raw_fragment))
        self.assertEqual(len(fragment_parser.container.all_services), 1)
        return fragment_parser.container.all_services.values()[0]

    def test_enabled(self):
        self._assert_parsed_proftpd(TestInetdConfParser.entry1,
                InetdService.ENABLED)

    def test_user_disabled(self):
        # A plain "# " prefix marks an entry disabled by the user.
        self._assert_parsed_proftpd("# %s" % TestInetdConfParser.entry1,
                InetdService.USER_DISABLED)

    def test_maintainer_disabled(self):
        # The "#<off># " prefix marks an entry disabled by the maintainer.
        self._assert_parsed_proftpd("#<off># %s" % TestInetdConfParser.entry1,
                InetdService.MAINT_DISABLED)

    def test_has_matching_entry(self):
        container = self._parse_conf_entry(TestInetdConfParser.entry1)
        fragment = self._parse_fragment('proftpd', TestInetdConfParser.result1)
        self.assertTrue(container.has_matching_entry(fragment))

    def test_has_matching_entry_with_different_srv_args(self):
        # Same as above but with a different server argument; it still
        # counts as a match.
        container = self._parse_conf_entry(TestInetdConfParser.entry1)
        fragment_elements = dict(TestInetdConfParser.result1,
                server_args='--debug')
        fragment = self._parse_fragment('proftpd', fragment_elements)
        self.assertTrue(container.has_matching_entry(fragment))

    def test_has_matching_entry_with_tcpd(self):
        container = self._parse_conf_entry(TestInetdConfParser.entry2)
        fragment = self._parse_fragment('nntp', TestInetdConfParser.result2)
        self.assertTrue(container.has_matching_entry(fragment))



if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
