#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Copyright(C) 2007-2008 INL
Written by Romain Bignon <romain AT inl.fr>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

$Id: tools.py 12192 2008-01-14 16:48:27Z romain $
"""

import SOAPpy
from IPy import IP
import sys, traceback
import socket
import time
from nevow import tags
from nevow.i18n import _, PlaceHolder, flattenL10n
from datetime import datetime
import re
from time import strptime

# Every run of these separators is replaced by '~' before matching, so the
# regexes below accept '-', '/', '.' or ':' interchangeably.
NORMALIZE_REGEX = re.compile("[-/.:]+")

# Date regex: DD-MM (the current year is used)
DATE_REGEX_SHORT = re.compile("^([0-3]?[0-9])~([01]?[0-9])$")

# Date regex: DD-MM-YY or DD-MM-YYYY
DATE_REGEX_LONG = re.compile("^([0-3]?[0-9])~([01]?[0-9])~([0-9]{1,4})$")

# Time regex: "HH:MM:SS" (FR format)
DATETIME_REGEX = re.compile("([0-9]{1,2})~([0-9]{2})~([0-9]{2})$")

# Display format of timestamps. Day-month-year order matches what
# parseDatetime() accepts, so a displayed date can be typed back as a filter.
DATETIME_FORMAT = '%d-%m-%Y %H:%M:%S'

def parseDatetime(gvalue, endofday=False):
    """
    Parse a user-supplied date (and optional time) into a Unix timestamp.

    Accepted forms (any of '-', '/', '.', ':' may be used as separator):
        DD-MM                   the current year is used
        DD-MM-YY or DD-MM-YYYY  two-digit years: 69-99 -> 19xx, 00-68 -> 20xx
        any of the above, whitespace, then HH:MM:SS

    When no (valid) time is given, 00:00:00 is used, or 23:59:59 when
    `endofday` is true, so that an upper bound covers the whole day.

    @param gvalue [string] date typed by the user
    @param endofday [bool] default the time to the last second of the day
    @return [int] seconds since the epoch, interpreted in local time (time.mktime)
    @raise Exception if the string doesn't describe a valid date/time
    """
    value = NORMALIZE_REGEX.sub("~", gvalue.strip())

    # split() (not split(' ')) so that several spaces between date and time
    # do not make the time part silently disappear.
    parts = value.split()
    date_part = ''
    time_part = ''
    if parts:
        date_part = parts[0]
    if len(parts) == 2:
        time_part = parts[1]

    # day/month stay at 0 when no date regex matches, which makes the
    # datetime() call below fail with a proper error message.
    day, month, year = 0, 0, datetime.today().year
    if endofday:
        hour, minute, sec = 23, 59, 59
    else:
        hour, minute, sec = 0, 0, 0

    # The regexes only capture digits, so int() cannot fail below.
    regs = DATE_REGEX_SHORT.match(date_part)
    if regs:
        day = int(regs.group(1))
        month = int(regs.group(2))

    regs = DATE_REGEX_LONG.match(date_part)
    if regs:
        day = int(regs.group(1))
        month = int(regs.group(2))
        year_text = regs.group(3)
        year = int(year_text)
        if len(year_text) <= 2:
            # Same pivot as strptime's %y (POSIX): 69-99 -> 1969-1999,
            # 00-68 -> 2000-2068.
            if year < 69:
                year += 2000
            else:
                year += 1900

    regs = DATETIME_REGEX.match(time_part)
    if regs:
        # The time regex has exactly three groups: hours, minutes, seconds.
        hour = int(regs.group(1))
        minute = int(regs.group(2))
        sec = int(regs.group(3))

    try:
        return int(time.mktime(datetime(year, month, day, hour, minute, sec).timetuple()))
    except (ValueError, OverflowError):
        # datetime() rejects out-of-range fields; mktime() may reject dates
        # the platform cannot represent.
        raise Exception("'%s' isn't any valid date or time." % gvalue)

def trans(ctx, trad):
    """
    Return `trad` rendered for the request context `ctx`.

    nevow PlaceHolder objects (presumably the lazy strings returned by
    nevow's `_()`) are flattened with flattenL10n(); any other value is
    returned unchanged.
    """
    if isinstance(trad, PlaceHolder):
        return flattenL10n(trad, ctx)
    return trad

def getBacktrace(empty="Empty backtrace."):
    """
    Return the traceback of the exception currently being handled, as a string.

    @param empty [string] returned when no exception is being handled
    @return [string] the formatted traceback, `empty`, or
            "Error while trying to get backtrace" if formatting itself fails
    """
    try:
        exc_type, exc_value, exc_tb = sys.exc_info()
        # Check the exception type rather than the formatted text: what
        # format_exception() prints for "no exception" varies between versions.
        if exc_type is None:
            return empty
        trace = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
        # Python 2 only: forget the handled exception so a later call does
        # not report it again. Python 3 has no sys.exc_clear().
        exc_clear = getattr(sys, 'exc_clear', None)
        if exc_clear is not None:
            exc_clear()
        return trace
    except Exception:
        # No i18n here (imagine if i18n function calls error...)
        return "Error while trying to get backtrace"

class Args:
    """
    Validate, filter and describe the arguments of a list request.

    `args` maps argument names to their values; `arg_types` describes every
    accepted argument name with an ArgType.
    """
    class ArgType:
        """ Description of an accepted argument: label, check and render hooks. """
        def __init__(self, label='', check_func=None, data_func=None, label_func=None, filter=True, links=None):
            """
                @param label [string] this is the label of this arg
                @param check_func [func] check_func(name, value) checks the value: it
                                  returns True, returns a replacement value, or raises
                                  (None and False are rejected by Args.check())
                @param data_func [func] data_func(value) creates a render of the value
                @param label_func [func] label_func(label, value) creates a render of the label
                @param filter [bool] is it a filter?
                @param links [list] names of arguments linked to this one (removed along
                                    with it by Args.remove()); a new empty list if omitted
            """
            self.label = label
            self.check_func = check_func
            self.data_func = data_func
            self.label_func = label_func
            self.filter = filter
            # A fresh list per instance: a shared default list would leak
            # links from one argument type to the others.
            if links is None:
                links = []
            self.links = links

    def __init__(self, args=None):
        """
            @param args [dict] arguments to analyze. Kept by reference: check()
                               may replace values in it. Defaults to a new empty dict.
        """
        if args is None:
            # Never share a default dict: check() writes into self.args.
            args = {}
        self.args = args

        # Labels of the list "functions", used by labels().
        self.functions = {'TCPTable':        _('TCP ports'),
                          'UDPTable':        _('UDP ports'),
                          'IPsrcTable':      _('source IP'),
                          'IPdstTable':      _('destination IP'),
                          'UserTable':       _('Users'),
                          'PacketTable':     _('Packet list'),
                          'ConUserTable':    _('Connected users'),
                          'PacketInfo':      _('Packet info'),
                          'AppTable':        _('Applications'),
                          'ConnTrackTable':  _('Connection tracking'),
                         }

        # Every accepted argument. TODO: use a global variable?
        self.arg_types = {'id':         self.ArgType('',                   self._check_int,     self._data_id),
                          'tcp_dport':  self.ArgType(_('TCP dest port'),   self._check_port,    self._data_port, self._label_port),
                          'tcp_sport':  self.ArgType(_('TCP source port'), self._check_port,    self._data_port, self._label_port),
                          'udp_dport':  self.ArgType(_('UDP dest port'),   self._check_port,    self._data_port, self._label_port),
                          'udp_sport':  self.ArgType(_('UDP source port'), self._check_port,    self._data_port, self._label_port),
                          'sport':      self.ArgType(_('Source port'),     self._check_port,    self._data_port, self._label_port),
                          'dport':      self.ArgType(_('Dest port'),       self._check_port,    self._data_port, self._label_port),
                          'ip_saddr':   self.ArgType(_('Source'),          self._check_ip,      self._data_ip,   self._label_ip),
                          'ip_daddr':   self.ArgType(_('Destination'),     self._check_ip,      self._data_ip,   self._label_ip),
                          'ip_addr':    self.ArgType(_('Address'),         self._check_ip,      self._data_ip,   self._label_ip),
                          'ip_from':    self.ArgType(),
                          'username':   self.ArgType(_('User'),                                 label_func=self._label_nothing),
                          'user_id':    self.ArgType(_('User'),            self._check_int,     label_func=self._label_user_id, links=['username']),
                          'userlike':   self.ArgType(_('Username contains')),
                          'state':      self.ArgType(_('State'),           self._check_int,     label_func=self._label_state),
                          'proto':      self.ArgType(_('Proto'),           self._check_proto),
                          'limit':      self.ArgType('',                   self._check_int,     filter=False),
                          'start':      self.ArgType('',                   self._check_int,     filter=True),
                          'sort':       self.ArgType(                                           filter=False),
                          'sortby':     self.ArgType(                                           filter=False),
                          '~render':    self.ArgType('',                 self._check_forbidden, filter=False),
                          'packets':    self.ArgType(_('Packets'),         self._check_forbidden),
                          'begin':      self.ArgType(_('First packet'),    self._check_datetime,  self._data_timestamp, self._label_timestamp),
                          'end':        self.ArgType(_('Last packet'),     self._check_datetime,  self._data_timestamp, self._label_timestamp),
                          'timestamp':  self.ArgType(_('Time'),            None,                  self._data_timestamp, self._label_timestamp),
                          'currents':   self.ArgType(),
                          'client_app': self.ArgType(_('Application'),     None,                  self._data_app),
                          'tiny':       self.ArgType(filter=False),
                          'oob_prefix': self.ArgType(_('Prefix')),
                          'start_time': self.ArgType(_('Begin')),
                          'end_time':   self.ArgType(_('End')),
                          'os_sysname': self.ArgType(_('System')),
                          }

    def _select(self, want_filters):
        """ Return the known arguments whose ArgType.filter truth value equals `want_filters`. """
        return dict((key, value) for key, value in self.args.items()
                    if key in self.arg_types and bool(self.arg_types[key].filter) == want_filters)

    def no_filters(self):
        """ Return the known arguments that are not filters (limit, sort, ...). """
        return self._select(False)

    def filters(self):
        """ Return the known arguments that are filters. """
        return self._select(True)

    def remove(self, key):
        """
            Return a dict mapping `key`, and every argument linked to it, to None.
            Presumably meant to be passed as **margs to args2url(), where a None
            value removes the argument.
        """
        removed = {key: None}
        arg_type = self.arg_types.get(key)
        if arg_type is not None:
            for linked in arg_type.links:
                removed[linked] = None
        return removed

    ###################
    #       DATA      #
    ###################

    def get_data(self, label, value):
        """ It is a post-analyze of data. If type is in my internal list, I will
            transform data output.

            A tuple is reduced to its first item; None and (None, None) give ''.
            Without a usable data_func, or if it fails, str(value) is returned.
        """
        if value is None or value == (None, None):
            return ''

        if isinstance(value, tuple):
            value = value[0]

        try:
            return self.arg_types[label].data_func(value)
        except Exception:
            # Unknown argument, no data_func (None isn't callable) or a
            # rendering error: fall back to the raw value.
            return str(value)

    def _data_id(self, value):
        """
            On backend, we put 'user' on label's id when this is
            a user packet, and 'host' when this is a non authenticated
            packet.

            We show an image and link on this points to the packet ID
        """
        if value == 'user':
            return tags.img(src='img/user.gif')
        elif value == 'host':
            return tags.img(src='img/host.gif')
        else:
            return value

    def _data_ip(self, value):
        """ We truncate IP if this is too long (an IPv6 for example) """
        text = str(value)
        if len(text) > 15:
            # Slice the string form: `value` itself may not be a string.
            return '...%s' % text[-15:]
        return value

    def _data_app(self, value):
        """ Show only the filename (last '/' or '\\' separated component) """
        name = value.split('/')[-1].split('\\')[-1]
        if not name:
            # Path ends with a separator: keep it whole.
            return value
        return name

    def _data_port(self, value, proto=None):
        """
            Show the service associated with this port number.
            @arg proto [string] Only used if you want to call this function yourself
            @return the service name, or `value` unchanged if none is found
        """
        try:
            port = int(value)
            if proto:
                return socket.getservbyport(port, proto)
            elif 'proto' in self.args:
                return socket.getservbyport(port, self.args['proto'])
            else:
                return socket.getservbyport(port)
        except Exception:
            # Not a number, unknown service or lookup error.
            return value

    def _data_timestamp(self, value):
        """ Format a timestamp (or any object with strftime()) with DATETIME_FORMAT. """
        try:
            return time.strftime(DATETIME_FORMAT, time.localtime(int(value)))
        except Exception:
            # Not a usable number: maybe a date/datetime object.
            if hasattr(value, 'strftime'):
                return value.strftime(DATETIME_FORMAT)
            return value

    ###################
    #      LABELS     #
    ###################

    def sort_label(self):
        """ With my args, get a label to say on what field my list is ordered. """
        sortby = self.args.get('sortby')
        if sortby not in self.arg_types:
            return _('No sort')

        return _('Sorted by %s') % (self.arg_types[sortby].label)

    def _label_ip(self, key, value):
        """ Label with the reverse-DNS name of the address when it resolves. """
        try:
            return '%s %s' % (key, socket.gethostbyaddr(value)[0])
        except Exception:
            return '%s %s' % (key, value)

    def _label_state(self, key, value):
        """ Label with the state name of an integer state code. """
        return '%s %s' % (key, i2state(value))

    def _label_timestamp(self, key, value):
        """ Label with the formatted timestamp; every branch keeps the `key` prefix. """
        return '%s %s' % (key, self._data_timestamp(value))

    def _label_port(self, key, value):
        """ Return port service if any. """
        # Take the protocol from a "<proto>_<rest>" key such as 'tcp_dport'.
        # NOTE(review): labels_dict() passes the translated label text (e.g.
        # 'TCP dest port') as `key`, which normally has no '_', so proto is
        # usually None and _data_port() falls back to the 'proto' argument.
        # Confirm whether the argument name was meant here.
        proto = None
        regs = re.match('(.*)_(.*)', key)
        if regs:
            proto = regs.group(1)

        port = self._data_port(value, proto)
        # A number back means no service name was found.
        if isinstance(port, (int, long)) or isinstance(port, str) and port.isdigit():
            return '%s %s' % (key, value)
        return '%s %s (%s)' % (key, value, port)

    def _label_user_id(self, key, value):
        """ Prefer the 'username' argument to the numeric user id. """
        if 'username' in self.args:
            return '%s %s' % (key, self.args['username'])
        return '%s %s' % (key, value)

    def _label_nothing(self, key, value):
        """ Empty label: labels_dict() skips empty labels. """
        return ''

    def labels_dict(self, ctx):
        """ Map each labelled, non-empty argument name to its rendered label. """
        labels = dict()

        for key, value in self.args.items():
            arg_type = self.arg_types.get(key)
            if not value or arg_type is None:
                continue
            # Without a label_func, only labelled filters are shown.
            if not arg_type.label_func and (not arg_type.label or not arg_type.filter):
                continue

            label = trans(ctx, arg_type.label)
            if arg_type.label_func:
                text = arg_type.label_func(label, value)
                if text:
                    labels[key] = text
            else:
                labels[key] = '%s %s' % (label, value)

        return labels

    def labels(self, ctx, function=''):
        """ Get a label to show on what criteria this list is built. """
        label = ', '.join(self.labels_dict(ctx).values())

        # If a function is given, we use the label associated.
        if function and function in self.functions:
            name = trans(ctx, self.functions[function])
            if label:
                label = '%s %s %s' % (name, trans(ctx, _('for')), label)
            else:
                label = name

        return label

    ###################
    #       CHECK     #
    ###################

    def check(self):
        """
            Check if types of arguments are good.

            Values may be replaced by what their check_func returns (e.g. a
            date string becomes a timestamp).
            @return True
            @raise Exception on an unknown argument name or an invalid value
        """
        for arg in self.args:
            value = self.args[arg]
            # '~'-prefixed arguments and empty values are not checked.
            if arg.startswith('~') or not value:
                continue

            if arg not in self.arg_types:
                raise Exception(_('%s isn\'t a valid argument name') % arg)

            check_func = self.arg_types[arg].check_func
            if check_func is None:
                continue

            ret = check_func(arg, value)

            # Only None/False mean "invalid": a normalized value may be a
            # legitimate falsy value such as timestamp 0.
            if ret is None or ret is False:
                raise Exception(_("'%(key)s' isn't a valid value for %(value)s argument") %
                                  {'key': value,
                                   'value': arg})
            if ret is not True:
                # If function returns something else than "True", it is because
                # we want to replace the value.
                self.args[arg] = ret

        return True

    def _check_port(self, arg, value):
        """ Accept an integer port in 1-65535. """
        self._check_int(arg, value)    # raises if value isn't an integer
        if not 1 <= int(value) <= 65535:
            raise Exception(_('%s value may be in range 1-65535') % arg)

        return True

    def _check_int(self, arg, value):
        """ Accept anything int() can convert. """
        try:
            int(value)
        except (TypeError, ValueError, OverflowError):
            raise Exception(_('%s value may be an integer') % arg)
        return True

    def _check_proto(self, arg, value):
        """ Accept 'tcp', 'udp' or 'icmp'. """
        if value not in ('tcp', 'udp', 'icmp'):
            raise Exception(_('%(key)s value may be tcp, udp or icmp (and not %(value)s)') % {'key': arg, 'value': value})

        return True

    def _check_ip(self, arg, value):
        """ This function is useless because now we can give an hostname """
        try:
            IP(value)
        except Exception:
            # Not an IP address: accept hostname-like strings.
            # NOTE(review): the pattern is not anchored at the end, so only a
            # valid prefix is required and the 128 bound is never enforced.
            if not re.match(r'[-a-zA-Z0-9._]{2,128}', value):
                raise Exception(_('Please give a correct IP address or hostname'))

        return True

    def _check_datetime(self, arg, value):
        """ Return the timestamp from the date string """
        try:
            # Already an integer (a timestamp): keep it as is.
            int(value)
            return value
        except (TypeError, ValueError, OverflowError):
            # It isn't an integer, so parse it as a date/time.
            pass

        # The 'end' bound defaults to the last second of its day.
        return parseDatetime(value, endofday=(arg == 'end'))

    def _check_forbidden(self, arg, value):
        """ This function is used if argument linked is forbidden """
        raise Exception(_('%s isn\'t a valid argument name') % arg)

def i2state(i):
    """
    Get the state label from an integer state code.

    @param i [int or string] state code; anything int() accepts
    @return [string] the state label, or None when `i` isn't an integer
            or isn't a known state code
    """
    try:
        code = int(i)
    except (TypeError, ValueError, OverflowError):
        return None

    # Do not translate because this is used for CSS classes.
    # Codes 1 and 4 both map to 'open'.
    states = {-1: 'all',
               0: 'drop',
               1: 'open',
               2: 'established',
               3: 'close',
               4: 'open',
             }

    return states.get(code)

def args2url(args, **margs):
    """
        Build URL query arguments from `args`, updated with `margs`.

        The keys of `args` come first, then the keys only present in `margs`.
        A `margs` value replaces the one from `args`. None or '' removes the
        argument; 0 is kept as a real value.

        For example:
        >>> args2url({'truc': 'machin', 'misc': 'foo'}, bz='flag', misc=None)
        '?truc=machin&bz=flag'

        The result is always at least '?'.

        NOTE(review): keys and values are inserted verbatim, without URL
        escaping, so a value containing '&', '=' or '#' breaks the query string.

        @param args [dict or SOAPpy structType] original arguments list
        @param **margs [dict] arguments to add, replace, or remove (None or '')
        @return [string] formatted url arguments
    """
    def _is_blank(value):
        # None and '' mean "remove". Compare with '' rather than calling str():
        # str() fails on non-ASCII unicode under Python 2.
        return value is None or value == ''

    # A plain dict can't be a SOAPpy struct, so SOAPpy is only consulted
    # for other types.
    if not isinstance(args, dict) and isinstance(args, SOAPpy.Types.structType):
        args = args._asdict()

    pairs = []
    for arg in args:
        if arg in margs:
            if _is_blank(margs[arg]):
                continue
            value = margs[arg]
        else:
            value = args[arg]
        pairs.append('%s=%s' % (arg, value))

    for arg in margs:
        if arg not in args and not _is_blank(margs[arg]):
            pairs.append('%s=%s' % (arg, margs[arg]))

    return '?' + '&'.join(pairs)
