#!/usr/bin/env python
import csv, sys, os, re
try:
    from pysqlite2 import dbapi2 as sqlite3
except:
    import sqlite3
from argparse import ArgumentParser, SUPPRESS
import gzip
try:
    # not all platforms/installations of python support bz2
    import bz2
    bz2_support = True
except:
    bz2_support = False
try:
    try:
        import lzma
    except:
        import backports.lzma as lzma
    lzma_support = True
except:
    lzma_support = False

def openFile(filename):
    """Open *filename* for binary reading, decompressing it transparently.

    The decoder is picked from the case-insensitive file extension:
    ``.gz`` uses gzip, ``.bz2`` uses bz2 and ``.xz`` uses lzma.  Any other
    name is opened as a plain file.  Every branch hands back a binary-mode
    object: ``readline()`` on a gzip stream yields bytes, so plain files are
    opened in ``'rb'`` too, letting callers treat all inputs the same way
    (including under Python 3).

    Raises ValueError when the file needs bz2 or xz support that this Python
    installation does not provide.
    """
    lowered = filename.lower()
    if lowered.endswith('.gz'):
        return gzip.open(filename, 'rb')
    if lowered.endswith('.bz2'):
        if bz2_support:
            return bz2.BZ2File(filename, 'rb')
        raise ValueError("Cannot process bz2 files with your operating system")
    if lowered.endswith('.xz'):
        if lzma_support:
            return lzma.open(filename, 'rb')
        raise ValueError("Cannot process xz files with your operating system")
    # Uncompressed text file, still opened in binary mode for consistency
    # with the compressed branches above.
    return open(filename, 'rb')

# Upper-case words that must not be used as bare table/column names.  Any
# identifier whose upper-cased form is in this set gets a leading '_' (see
# SQLiteMan._legalize_name and the header handling in SQLiteMan.convert).
# NOTE(review): the list appears to be based on MySQL reserved words (e.g.
# ZEROFILL, SONAME, SQL_BIG_RESULT) plus common SQL function names (ABS, AVG,
# COUNT, ...); SQLite-specific keywords such as PRAGMA, VACUUM or ATTACH are
# not present -- TODO confirm whether they should be.  Several entries appear
# more than once, which is harmless because this is a set.
SQL_KEYWORDS = set([
    'ADD', 'ALL', 'ALTER', 'ANALYZE', 'AND', 'AS', 'ASC', 'ASENSITIVE', 'BEFORE',
    'BETWEEN', 'BIGINT', 'BINARY', 'BLOB', 'BOTH', 'BY', 'CALL', 'CASCADE', 'CASE',
    'CHANGE', 'CHAR', 'CHARACTER', 'CHECK', 'COLLATE', 'COLUMN', 'CONDITION',
    'CONSTRAINT', 'CONTINUE', 'CONVERT', 'CREATE', 'CROSS', 'CURRENT_DATE',
    'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_USER', 'CURSOR', 'DATABASE',
    'DATABASES', 'DAY_HOUR', 'DAY_MICROSECOND', 'DAY_MINUTE', 'DAY_SECOND', 'DEC',
    'DECIMAL', 'DECLARE', 'DEFAULT', 'DELAYED', 'DELETE', 'DESC',
    'DESCRIBE', 'DETERMINISTIC', 'DISTINCT', 'DISTINCTROW', 'DIV', 'DOUBLE',
    'DROP', 'DUAL', 'EACH', 'ELSE', 'ELSEIF', 'ENCLOSED', 'ESCAPED', 'EXISTS',
    'EXIT', 'EXPLAIN', 'FALSE', 'FETCH', 'FLOAT', 'FLOAT4', 'FLOAT8', 'FOR',
    'FORCE', 'FOREIGN', 'FROM', 'FULLTEXT', 'GRANT', 'GROUP', 'HAVING', 'HIGH_PRIORITY',
    'HOUR_MICROSECOND', 'HOUR_MINUTE', 'HOUR_SECOND', 'IF', 'IGNORE', 'IN',
    'INDEX', 'INFILE', 'INNER', 'INOUT', 'INSENSITIVE', 'INSERT',
    'INT', 'INT1', 'INT2', 'INT3', 'INT4', 'INT8', 'INTEGER', 'INTERVAL', 'INTO',
    'IS', 'ITERATE', 'JOIN', 'KEY', 'KEYS', 'KILL', 'LEADING', 'LEAVE', 'LEFT',
    'LIKE', 'LIMIT', 'LINES', 'LOAD', 'LOCALTIME', 'LOCALTIMESTAMP',
    'LOCK', 'LONG', 'LONGBLOB', 'LONGTEXT', 'LOOP', 'LOW_PRIORITY', 'MATCH',
    'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT', 'MIDDLEINT', 'MINUTE_MICROSECOND',
    'MINUTE_SECOND', 'MOD', 'MODIFIES', 'NATURAL', 'NOT', 'NO_WRITE_TO_BINLOG',
    'NULL', 'NUMERIC', 'ON', 'OPTIMIZE', 'OPTION', 'OPTIONALLY', 'OR',
    'ORDER', 'OUT', 'OUTER', 'OUTFILE', 'PRECISION', 'PRIMARY', 'PROCEDURE',
    'PURGE', 'READ', 'READS', 'REAL', 'REFERENCES', 'REGEXP', 'RELEASE',
    'RENAME', 'REPEAT', 'REPLACE', 'REQUIRE', 'RESTRICT', 'RETURN',
    'REVOKE', 'RIGHT', 'RLIKE', 'SCHEMA', 'SCHEMAS', 'SECOND_MICROSECOND',
    'SELECT', 'SENSITIVE', 'SEPARATOR', 'SET', 'SHOW', 'SMALLINT',
    'SONAME', 'SPATIAL', 'SPECIFIC', 'SQL', 'SQLEXCEPTION', 'SQLSTATE',
    'SQLWARNING', 'SQL_BIG_RESULT', 'SQL_CALC_FOUND_ROWS', 'SQL_SMALL_RESULT',
    'SSL', 'STARTING', 'STRAIGHT_JOIN', 'TABLE', 'TERMINATED',
    'THEN', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'TO', 'TRAILING',
    'TRIGGER', 'TRUE', 'UNDO', 'UNION', 'UNIQUE', 'UNLOCK', 'UNSIGNED',
    'UPDATE', 'USAGE', 'USE', 'USING', 'UTC_DATE', 'UTC_TIME', 'UTC_TIMESTAMP', 'VALUES',
    'VARBINARY', 'VARCHAR', 'VARCHARACTER', 'VARYING', 'WHEN', 'WHERE', 'WHILE',
    'WITH', 'WRITE', 'XOR', 'YEAR_MONTH', 'ZEROFILL', 'ASENSITIVE', 'CALL', 'CONDITION',
    'CONNECTION', 'CONTINUE', 'CURSOR', 'DECLARE', 'DETERMINISTIC', 'EACH',
    'ELSEIF', 'EXIT', 'FETCH', 'GOTO', 'INOUT', 'INSENSITIVE', 'ITERATE', 'LABEL', 'LEAVE',
    'LOOP', 'MODIFIES', 'OUT', 'READS', 'RELEASE', 'REPEAT', 'RETURN', 'SCHEMA', 'SCHEMAS',
    'SENSITIVE', 'SPECIFIC', 'SQL', 'SQLEXCEPTION', 'SQLSTATE', 'SQLWARNING', 'TRIGGER',
    'UNDO', 'UPGRADE', 'WHILE', 'ABS', 'ACOS', 'ADDDATE', 'ADDTIME', 'ASCII', 'ASIN',
    'ATAN', 'AVG', 'BETWEEN', 'AND', 'BINARY', 'BIN', 'BIT_AND',
    'BIT_OR', 'CASE', 'CAST', 'CEIL', 'CHAR', 'CHARSET', 'CONCAT', 'CONV', 'COS', 'COT',
    'COUNT', 'DATE', 'DAY', 'DIV', 'EXP', 'IS', 'LIKE', 'MAX', 'MIN', 'MOD', 'MONTH',
    'LOG', 'POW', 'SIN', 'SLEEP', 'SORT', 'STD', 'VALUES', 'SUM'
])

class SQLiteMan:
    """Thin wrapper around one SQLite connection.

    Imports delimited text files as tables, lists tables and fields, and runs
    ad-hoc queries.  A single cursor ``self.c`` is shared by every method and
    stays open for the lifetime of the object.  This lets several operations
    (e.g. an import followed by ``getTables()`` or ``execute()``) run on the
    same instance.
    """

    def __init__(self, dbpath):
        """Connect to the SQLite database at *dbpath* and apply pragmas that
        favour import speed over durability."""
        self.conn = sqlite3.connect(dbpath)
        try:
            # Allows load_extension() later, e.g. for a library built with
            # gcc -fPIC -lm -shared extension-functions.c -o libsqlitefunctions.so
            self.conn.enable_load_extension(True)
        except (AttributeError, sqlite3.Error):
            # Best effort: some sqlite3 builds lack extension-loading support.
            pass
        self.c = self.conn.cursor()
        # No fsync, and journal/temp data kept in memory: fast bulk imports.
        self.c.execute('pragma synchronous=off')
        self.c.execute('pragma count_changes=off')
        self.c.execute('pragma journal_mode=memory')
        self.c.execute('pragma temp_store=memory')

    def convert(self, filepath_or_fileobj, table = 'data', delim = None, header_option = None, force = False):
        """Import a delimited text file into a new table.

        filepath_or_fileobj -- a path (opened with openFile, so .gz/.bz2/.xz
            inputs are decompressed) or an already-open, seekable file object.
        table -- table name; made a legal identifier by _legalize_name().
        delim -- field delimiter, sniffed from the first line when None.  The
            two-character string backslash-t means a tab.
        header_option -- None: the first line is the header; False: there is
            no header and columns are named V1..Vn; a list: use it as column
            names, where ['-'] means read one header line from stdin.
        force -- drop an existing table of the same name instead of exiting.

        Column types are guessed from the first rows (see _guess_types).
        The input is closed afterwards only if this method opened it.
        Calls sys.exit() if the table exists and force is False.
        """
        # @author: Rufus Pollock
        # Placed in the Public Domain
        table = self._legalize_name(table)
        if table in self.getTables():
            if force:
                self.c.execute('DROP TABLE IF EXISTS {}'.format(table))
            else:
                sys.exit("Table '{}' already exists!".format(table))
        try:
            string_types = basestring  # Python 2: str and unicode
        except NameError:
            string_types = str  # Python 3
        owns_file = isinstance(filepath_or_fileobj, string_types)
        fo = openFile(filepath_or_fileobj) if owns_file else filepath_or_fileobj
        wrapped = False
        if sys.version_info[0] >= 3:
            import io
            if not isinstance(fo, io.TextIOBase):
                # The csv module needs text under Python 3, but openFile()
                # returns binary streams.  UTF-8 input is an assumption -- TODO confirm.
                fo = io.TextIOWrapper(fo, encoding='utf-8', newline='')
                wrapped = True
        try:
            self._import_rows(fo, table, delim, header_option)
        finally:
            if owns_file:
                fo.close()
            elif wrapped:
                # Release our wrapper without closing the caller's stream.
                fo.detach()

    def _import_rows(self, fo, table, delim, header_option):
        """Create *table* from the text stream *fo* and insert all its rows."""
        if delim is None:
            # guess the delimiter from the first line
            delim = csv.Sniffer().sniff(fo.readline()).delimiter
            fo.seek(0)
        if delim == '\\t':
            delim = '\t'

        if header_option is None or header_option is False:
            first_row = next(csv.reader(fo, delimiter = delim), None)
            if first_row is None:
                sys.exit("No data found for table '{}'!".format(table))
            if header_option is None:
                # first line is header
                headers = [self._legalize_name(x) for x in first_row]
            else:
                # no header: name columns V1..Vn after the width of the first row
                headers = ["V{}".format(i + 1) for i in range(len(first_row))]
        elif header_option == ['-']:
            headers = self._read_stdin_header(delim)
        else:
            # copy, so the caller's list is not modified below
            headers = list(header_option)
        # avoid column names that clash with SQL keywords
        headers = ['_' + h if h.upper() in SQL_KEYWORDS else h for h in headers]

        def data_rows():
            # Rewind and return a csv reader positioned at the first data row.
            fo.seek(0)
            rows = csv.reader(fo, delimiter = delim)
            if header_option is None:
                next(rows, None)
            return rows

        types = self._guess_types(data_rows(), headers)
        # Identifiers are double-quoted, so embedded double quotes are doubled.
        columns = ','.join('"%s" %s' % (h.replace('"', '""'), t)
                           for h, t in zip(headers, types))
        self.c.execute('CREATE table %s (%s)' % (table, columns))

        insert_sql = 'insert into %s values (%s)' % (table, ','.join(['?'] * len(headers)))
        numeric = [t in ('real', 'integer') for t in types]
        for row in data_rows():
            if not row:
                # blank line: nothing to insert
                continue
            # take commas out of ints and floats so sqlite recognizes them as numbers
            row = [x.replace(',', '') if is_num else x for x, is_num in zip(row, numeric)]
            self.c.execute(insert_sql, row)
        self.conn.commit()

    def _read_stdin_header(self, delim):
        """Read column names from the first line of stdin, split on *delim*."""
        line = sys.stdin.readline().strip()
        if not line:
            sys.exit("No header line found on stdin!")
        if delim == '\t':
            # Accept real tabs as well as a literal backslash-t separator
            # (as produced e.g. by a plain echo without -e).
            line = line.replace('\\t', '\t')
        return line.split(delim)

    def _guess_types(self, reader, headers, max_sample_size=100):
        '''Guess SQLite column types ('integer', 'real' or 'text') of CSV rows.

        Reads at most *max_sample_size* rows from *reader*.  After stripping
        whitespace and ',' separators, a column is 'integer' if every sampled
        non-empty cell parses with int(), else 'real' if every such cell
        parses with float(), else 'text'.  Empty cells fit any type, so an
        all-empty column is 'integer'.  Cells beyond len(headers) are ignored.
        '''
        # @author: Rufus Pollock
        # Placed in the Public Domain
        ncol = len(headers)
        # per column: could every sampled cell so far be read as int / float?
        as_int = [True] * ncol
        as_real = [True] * ncol
        for count, row in enumerate(reader):
            if count >= max_sample_size:
                break
            for idx, cell in enumerate(row[:ncol]):
                # replace ',' with '' to improve cast accuracy for ints and floats
                cell = cell.strip().replace(',', '')
                if not cell:
                    continue
                if as_int[idx]:
                    try:
                        int(cell)
                    except ValueError:
                        as_int[idx] = False
                if as_real[idx]:
                    try:
                        float(cell)
                    except ValueError:
                        as_real[idx] = False
        return ['integer' if i else 'real' if r else 'text'
                for i, r in zip(as_int, as_real)]

    def _legalize_name(self, name):
        """Turn *name* into a bare SQL identifier.

        Every character outside [a-zA-Z0-9_] becomes '_'.  A leading '_' is
        added when the result is empty, starts with a digit, or matches an
        entry of SQL_KEYWORDS (case-insensitively).
        """
        output = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if not output or output[0].isdigit() or output.upper() in SQL_KEYWORDS:
            output = '_' + output
        return output

    def getTables(self):
        """Return the sorted, de-duplicated tbl_name values of sqlite_master."""
        # Index and trigger rows repeat their table's name in tbl_name, so
        # DISTINCT keeps a table with indexes from being listed several times.
        return sorted(row[0] for row in
                      self.c.execute("SELECT DISTINCT tbl_name FROM sqlite_master"))

    def getFields(self, table):
        """Return sorted (lower-cased column name, declared type) pairs of *table*."""
        # Escape single quotes because the name is embedded in a string literal.
        quoted = table.replace("'", "''")
        rows = self.c.execute("PRAGMA table_info('{0}')".format(quoted))
        return sorted((item[1].lower(), item[2]) for item in rows)

    def execute(self, query):
        """Run *query*, print each result row comma-joined, then commit."""
        for item in self.c.execute(query).fetchall():
            print (','.join(map(str, item)))
        self.conn.commit()

    def load_extension(self, ext):
        """Load the SQLite extension library at *ext*; report failures on stderr."""
        try:
            self.conn.load_extension(ext)
        except Exception as e:
            sys.stderr.write('Cannot load extension "{}" ({})! Perhaps python-sqlite3 version too old?\n'.\
                             format(ext,e))


if __name__ == '__main__':
    # Command-line entry point: optionally import a text file, list tables or
    # fields, load an extension, then run any leftover arguments as a query.
    parser = ArgumentParser(description = "Simple database utility", prog = os.path.split(sys.argv[0])[-1],
        fromfile_prefix_chars = '@', epilog = '''Gao Wang, 2014''')
    parser.add_argument('--version', action='version', version='%(prog)s version {0}'.format('0.0.1'))
    parser.add_argument('database', help = 'DB filename')
    parser_import = parser.add_argument_group('Import options')
    parser_import.add_argument('-i', '--import', metavar = 'string', dest = 'convert', help = 'text filename')
    parser_import.add_argument('--as', metavar = 'string', dest = 'table',
                               help = 'text filename alias in database')
    parser_import.add_argument('-d', '--delimiter', metavar = 'string',
                               help = 'text file delimiter')
    parser_import.add_argument('--no-header', action = 'store_true', dest = 'no_header',
                               help = 'input text file does not have a header')
    parser_import.add_argument('--header', metavar = 'vector', nargs='+',
                               help = 'provide headers from command (use "-" if header comes from stdin)')
    parser_import.add_argument('-f', '--force', action = 'store_true',
                               help = 'delete existing table in database if exists and import new table')
    parser_query = parser.add_argument_group('Utilities')
    # --extension is a query utility, so it is listed under 'Utilities' in --help.
    parser_query.add_argument('--extension', metavar = 'string',
                              help = 'path to sqlite extension to load for use')
    parser_query.add_argument('-s', '--show', nargs = '*', help = 'list tables in database, or fields in table')
    # Unrecognized arguments are joined into one SQL query run at the end.
    args, argv = parser.parse_known_args()
    # When importing, the database name is forced to end with 'sqlite3'
    # (its extension is replaced by '.sqlite3'); otherwise it is used as given.
    if args.convert:
        database = args.database if args.database.endswith('sqlite3') else \
          os.path.splitext(args.database)[0] + '.sqlite3'
    else:
        database = args.database
    s = SQLiteMan(database)
    if args.convert:
        # Default table name: the input's base name without its last extension.
        table = args.table if args.table else os.path.splitext(os.path.split(args.convert)[-1])[0]
        if table.upper() in SQL_KEYWORDS:
            sys.stderr.write('Table name {0} renamed to _{0} due to conflict with SQL keywords!\n'.format(table))
        s.convert(args.convert, table, args.delimiter,
                  args.header if not args.no_header else False,
                  args.force)
    if args.show is not None:
        if len(args.show) == 0:
            print ('\n'.join(s.getTables()))
        else:
            for item in args.show:
                print ("\033[1m{}\033[0m".format(item))
                print ('\n'.join(['[{}] {}'.format(x[1], x[0]) for x in s.getFields(item)]))
    if args.extension:
        s.load_extension(os.path.abspath(os.path.expanduser(args.extension)))
    if argv:
        try:
            s.execute(' '.join(argv))
        except sqlite3.OperationalError as e:
            sys.exit(e)