
import MySQLdb, mx.DateTime, re

## When True, jsql.get_rows/set_value/save_row/delete_row/insert_row print
## the generated SQL instead of executing it.
Debug = False

class jsql(object):
    """
    Simple table-oriented helper around a MySQLdb connection.

    An instance is bound to one database and (optionally) one table.  It
    builds SQL text for SHOW / SELECT / UPDATE / INSERT / DELETE statements
    and renders column values as SQL literals through one jtype per column
    (built from SHOW COLUMNS output).

    NOTE(review): values are embedded in the SQL text as literals produced
    by getquoted/jtype rather than passed as driver parameters; database,
    table and field names and raw (string) WHERE / ORDER BY fragments are
    interpolated verbatim -- only pass trusted strings for those.
    """

    ## Per-instance state is created in __init__ (not at class level) so that
    ## instances never share the same mutable list/dict objects:
    ##   rowdata      - list of current row data, in original column order
    ##   rows         - result of the last get_rows()
    ##   types        - dict: column name -> jtype
    ##   pk_fieldname - name of the primary-key column (None if none found)

    __table = ""          ## Current selected table
    __recno = -1          ## Current record number in multiple-row result

    ## Both str and (on Python 2) unicode are treated as ready-made SQL text.
    _STRING_TYPES = (str, type(u""))

    def __init__( self, server, user, password, database, table=""):
        """
        Connect to `server` as `user` and optionally select `table` in
        `database` (selecting a table reads its column definitions).
        """
        assert server
        assert database
        assert user

        self.rowdata      = []
        self.rows         = []
        self.types        = {}
        self.result       = None
        self.command      = ""
        self.pk_fieldname = None
        self.pk_info      = None
        self.__columns    = ()

        self.server     = server
        self.user       = user
        self.password   = password
        self.database   = database
        self.__connect()             ## establish the mysql connection

        self.table      = table      ## must assign table after connecting

    def __connect(self):
        """
        Open the MySQL connection and a cursor on it.
        """
        self.db = MySQLdb.connect( self.server,
                                   user=self.user,
                                   passwd=self.password)
        assert self.db
        self.cur = self.db.cursor()
        assert self.cur

    def __del__(self):
        """
        Close the connection when the object is released.  Guarded because
        __init__ may have failed before the connection existed.
        """
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def __execute(self, templateinfo=None):
        """
        Fill self.command with `templateinfo` (%-formatting), execute it and
        store all fetched rows in self.result.  On a database error the SQL
        and the error are printed and self.result is set to None.
        """
        assert self.command
        if templateinfo:
            self.command = self.command % templateinfo

        try:
            self.cur.execute( self.command)
            self.result = self.cur.fetchall()
            ## MySQLdb opens connections with autocommit off; without an
            ## explicit commit, UPDATE/INSERT/DELETE on transactional
            ## (InnoDB) tables are discarded when the connection closes.
            self.db.commit()
        except Exception:
            import sys
            self.result = None
            print("Class error: Error in SQL command")
            print(self.command)
            print(sys.exc_info()[1])

    def get_columns(self):
        """
        Run SHOW COLUMNS for the current table, refresh self.types and the
        primary-key info from it, and return the raw result (None when no
        table is selected or the query failed).
        """
        if not self.table:
            return None
        self.command = "SHOW COLUMNS FROM %s.%s "
        self.__execute((self.database, self.table))
        if self.result:
            self.__prep_row(self.result)
        return self.result

    def get_fields(self):
        """Return the table's column names as a comma-delimited string."""
        return self.__columnstring()

    def __columnstring(self):
        """ Returns the tables column names as a comma-delimited
            string for use in SQL commands ("" if unavailable).
        """
        columns = self.get_columns() or ()
        return ",".join([row[0] for row in columns])

    def __prep_row(self, columns):
        """
        Build one jtype per column from a SHOW COLUMNS result and record the
        primary-key column.
        """
        assert columns
        for each in columns:
            self.types[each[0]] = jtype( each[1])   ## each[1] = column type
            if each[3] == "PRI":                    ## each[3] = primary key indicator
                self.types[each[0]].is_pk = True    ## mark type as pk
                self.pk_fieldname = each[0]         ## set pk fieldname (note - pk field will not be updated by generated sql commands)
                self.pk_info = each[5]              ## each[5] = additional info, may need it

    def get_tables(self):
        """
        Return table information from the database
        """
        self.command = "SHOW TABLES FROM %s"
        self.__execute((self.database,))
        return self.result

    def get_row(self, pk=0, filter=None):
        """
        Pass a pk value or a dict of field:value pairs for the WHERE filter.
        Loads the first matching row into self.rowdata and returns it.

        NOTE(review): with neither pk nor filter no query is run and the
        first row of the previous result (if any) is loaded.
        """
        if pk:
            strpk = getquoted( self.types[self.pk_fieldname], pk)
            self.command = "SELECT * FROM %s.%s WHERE %s = %s"
            self.__execute((self.database, self.table, self.pk_fieldname, strpk))
        elif filter:
            cWhere       = self.__get_where( filter)
            self.command = "SELECT * FROM %s.%s WHERE %s"
            self.__execute((self.database, self.table, cWhere))

        if self.result:
            self.__recno = 0
            self.rowdata = list(self.result[0])
        return self.rowdata

    def __get_reccount(self):
        return len( self.rows)

    def __get_recno(self):
        return self.__recno

    def __set_recno(self, recno):
        """Move to row `recno` of self.rows; out-of-range values are ignored."""
        if recno <= (self.reccount-1) and recno >= 0 and self.rows:
            ## Copy into a list so set_value can modify the current row.
            self.rowdata = list(self.rows[recno])
            self.__recno = recno

    def get_value(self, fieldname):
        """Return the current row's value for `fieldname` ("" if unknown)."""
        value = ""
        for i, column in enumerate(self.__columns):
            if column[0] == fieldname:
                value = self.rowdata[i]
        return value

    def __get_where(self, filter):
        """
         Builds a SQL-WHERE expression ("f1 LIKE v1 AND f2 LIKE v2")
         from a dict of field:value pairs.
        """
        assert isinstance(filter, dict)
        clauses = ["%s LIKE %s" % (field, getquoted(self.types[field], value))
                   for field, value in filter.items()]
        return " AND ".join(clauses)

    def get_rows( self, fields, filter=None, order=None):
        """
        SELECT rows from the current table, store them in self.rows and
        return them.

        fields - list/tuple of column names, a ready comma-separated string,
                 or "*" / empty for all columns.
        filter - dict of field:value pairs (LIKE match) or a raw WHERE string.
        order  - list/tuple of ORDER BY terms or a ready string.
        In Debug mode the SQL is printed and None is returned.
        """
        cWhere = ""
        cOrder = ""
        if not fields or fields == "*":
            fields = self.__columnstring()
        elif not isinstance(fields, self._STRING_TYPES):
            fields = ",".join(fields)
        if filter:
            if isinstance(filter, dict):
                cWhere = " WHERE %s " % (self.__get_where(filter))
            else:
                cWhere = " WHERE %s " % (filter)
        if order:
            if not isinstance(order, self._STRING_TYPES):
                order = ",".join(order)
            cOrder = " ORDER BY %s " % (order)
        self.command = "SELECT %s FROM %s.%s %s %s"

        if Debug:
            print(self.command % ((fields, self.database, self.table, cWhere, cOrder)))
        else:
            self.__execute((fields, self.database, self.table, cWhere, cOrder))
            self.rows = self.result
            return self.result

    def get_empty(self):
        """ Get an empty row (reset rowdata to each column type's empty value)
        """
        self.rowdata = [self.types[column[0]].empty for column in self.__columns]

    def set_value( self, fieldname, value, save=0):
        """
        Set `fieldname` in the current row to `value` (any value, including
        0 and ""); when `save` is true, also UPDATE that single column in
        the database for the current primary key.
        """
        assert fieldname
        for i, column in enumerate(self.__columns):
            if column[0] == fieldname:
                self.rowdata[i] = value
        if save:
            assert self.pk_fieldname
            strval = getquoted( self.types[fieldname], value)
            strpk  = getquoted( self.types[self.pk_fieldname], self.get_value(self.pk_fieldname))
            setexpr = fieldname + "=" + strval
            self.command = "UPDATE %s.%s SET %s WHERE %s=%s"

            if Debug:
                print(self.command % ((self.database, self.table, setexpr, self.pk_fieldname, strpk)))
            else:
                self.__execute((self.database, self.table, setexpr, self.pk_fieldname, strpk))

    def save_row(self):
        """ Save (UPDATE) every non-pk column of the current rowdata,
            matching the row by its primary key.
        """
        assert self.pk_fieldname
        assignments = []
        pk = ""
        for i, column in enumerate(self.__columns):
            field = column[0]
            value = getquoted(self.types[field], self.rowdata[i])
            if field == self.pk_fieldname:
                pk = "=".join((field, value))
            else:
                assignments.append("=".join((field, value)))

        fieldsvalues = ",".join(assignments)
        self.command = "UPDATE %s.%s SET %s WHERE %s"
        if Debug:
            print(self.command % ((self.database, self.table, fieldsvalues, pk)))
        else:
            self.__execute((self.database, self.table, fieldsvalues, pk))

    def delete_row(self):
        """ Delete the last selected row (matched by its primary key).
        """
        assert self.rowdata
        assert self.pk_fieldname
        pk = getquoted( self.types[self.pk_fieldname], self.get_value(self.pk_fieldname))
        self.command = "DELETE FROM %s.%s WHERE %s = %s "

        if Debug:
            print(self.command % ((self.database, self.table, self.pk_fieldname, pk)))
        else:
            self.__execute((self.database, self.table, self.pk_fieldname, pk))

    def insert_row(self):
        """ Add (INSERT) the rowdata.  The primary-key column is left out
            so the database can assign it.
            NOTE(review): assumes the pk is auto-generated (see pk_info).
        """
        fields = []
        values = []
        for i, column in enumerate(self.__columns):
            field = column[0]
            if field == self.pk_fieldname:
                continue
            fields.append(field)
            values.append(getquoted(self.types[field], self.rowdata[i]))

        self.command = "INSERT INTO %s.%s (%s) VALUES (%s)"
        if Debug:
            print(self.command % ((self.database, self.table, ",".join(fields), ",".join(values))))
        else:
            self.__execute((self.database, self.table, ",".join(fields), ",".join(values)))

    def __get_table(self):
        return self.__table

    def __set_table(self, value):
        """Select a table and reload its column definitions (clears old ones)."""
        self.__table = value
        ## Forget the previous table's column types and primary key.
        self.types        = {}
        self.pk_fieldname = None
        self.pk_info      = None
        if value:
            self.__columns = self.get_columns() or ()
        else:
            self.__columns = ()

    table    = property( __get_table, __set_table, None, "Get/set method for table property")
    recno    = property( __get_recno, __set_recno, None, "Set current row within multi-row set")
    reccount = property( __get_reccount, None, None, "Get current number of rows from multi-row set")
    
    
class jtype(object):
    """
    Converter between Python values and MySQL literal text for one column.

    Built from a column type string as reported by SHOW COLUMNS (e.g.
    "int(11)", "varchar(20)"); the leading alphabetic run, upper-cased, is
    stored in `type`.
      setnativevalue(v) - Python value -> strvalue, then quoted
      setstrvalue(s)    - string -> nativevalue, then quoted
    `quoted` is the text to embed in a SQL statement.

    Types missing from the tables below fall back to string handling
    (str conversion, "" as empty value, "TEXT" control, quoted string).
    """
    sqltypes = {}

    ## Leading alphabetic run of a column type, e.g. "int" in "int(11)";
    ## case-insensitive so "INT(11)" is recognised too.
    _typename = re.compile("[a-z]+", re.I)

    ## Trailing fractional-seconds part such as ".00" (present in the
    ## string form of mx.DateTime values).
    _fraction = re.compile(r"\.\d*$")

    _NUMERIC = ("INT","TINYINT","SMALLINT","MEDIUMINT","BIGINT","LONG","DECIMAL")

    def _parse_number(svalue):
        """
        Parse numeric text: int for integer literals, otherwise float.
        Replaces eval(), which would execute any expression typed into a
        numeric field.  Raises ValueError for non-numeric text.
        """
        try:
            return int(svalue)
        except ValueError:
            return float(svalue)

    ## { Type: (funcToStr, funcFromStr),...}
    sqlconv = {
               "INT":(str,int),
               "SMALLINT":(str,int),
               "TINYINT":(str,int),
               "MEDIUMINT":(str,int),
               "BIGINT":(str,int),
               "DATE": (str,mx.DateTime.DateFrom),
               "DATETIME":(str,mx.DateTime.DateTimeFrom),
               "TIMESTAMP":(str,mx.DateTime.DateTimeFrom),
               "VARCHAR":(str,str),
               "CHAR":(str,str),
               "LONGTEXT":(str,str),
               "LONG":(str,_parse_number),
               "DECIMAL":(str,_parse_number),
               }

    ## Wrapped only after sqlconv is built, so the dict holds the plain
    ## (directly callable) function.
    _parse_number = staticmethod(_parse_number)

    def _quote(text):
        """
        Return `text` as a double-quoted MySQL string literal, with
        backslashes and double quotes backslash-escaped.
        """
        return '"' + text.replace('\\', '\\\\').replace('"', '\\"') + '"'
    _quote = staticmethod(_quote)

    sqlempty = {
               "INT" : 0,
               "SMALLINT" : 0,
               "TINYINT" : 0,
               "MEDIUMINT" : 0,
               "BIGINT" : 0,
               "DATE" : mx.DateTime.DateFrom("01-01-1901"),
               "DATETIME" : mx.DateTime.DateFrom("01-01-1901"),
               "TIMESTAMP" : mx.DateTime.DateFrom("01-01-1901"),
               "VARCHAR" : "",
               "CHAR" : "",
               "LONGTEXT" : "",
               "LONG" : 0.00,
               "DECIMAL": 0.00,
               }

    sqlcontrols = {
               "INT" : "INT",           ## wxIntCtrl( panel, size=wxSize( 50, 20 ) )
               "SMALLINT" : "INT",
               "TINYINT" : "INT",
               "MEDIUMINT" : "INT",
               "BIGINT" : "INT",
               "DATE" : "CAL",
               "DATETIME" : "CALTIME",
               "TIMESTAMP" : "CALTIME",
               "VARCHAR" : "TEXT",
               "CHAR" : "TEXT",
               "LONGTEXT" : "EDIT",
               "LONG" : "TEXT",
               "DECIMAL": "TEXT",
               }


    def __init__(self, type=""):
        self.type        = self.settype(type)
        self.control     = self.get_defaultcontrol()
        self.nativevalue = None
        self.strvalue    = ""
        self.precision   = 0
        self.decimal     = 0
        self.quoted      = ""
        self.is_pk       = 0
        self.is_readonly = 0

    def settype( self, ctype):
        """
        Derive the upper-case type name from a column type string
        ("int(11)" -> "INT"), set self.empty to that type's empty value
        ("" for blank/unknown types) and return the name.
        """
        self.empty = ""
        if not ctype:
            return ""
        match = self._typename.search(ctype)
        if match is None:
            return ""
        typename = match.group().upper()
        self.empty = self.sqlempty.get(typename, "")
        return typename

    def setnativevalue(self, nvalue):
        """Store a Python value, derive strvalue/quoted from it; return strvalue."""
        functostr = self.sqlconv.get(self.type, (str, str))[0]
        try:
            self.strvalue = functostr(nvalue)
        except Exception:
            print("type: %s, value: %s" % (self.type , str(nvalue)))
        self._setquoted()
        self.nativevalue = nvalue
        return self.strvalue

    def setstrvalue(self, svalue):
        """Store a string value, derive nativevalue/quoted; return nativevalue."""
        funcfromstr = self.sqlconv.get(self.type, (str, str))[1]
        try:
            if svalue:
                self.nativevalue = funcfromstr(svalue)
            else:
                self.nativevalue = getblank(self.type)
        except Exception:
            print("type: %s, value: %s" % (self.type , str(svalue)))

        self.strvalue = svalue
        self._setquoted()
        return self.nativevalue

    def _setquoted(self):
        """Render self.strvalue as SQL literal text into self.quoted."""
        if self.type in self._NUMERIC:
            self.quoted = self.strvalue
            if not self.quoted:
                self.quoted = '0'
            else:
                try:
                    float(self.quoted)
                except ValueError:
                    ## Not numeric (e.g. free text typed into a numeric
                    ## field): emit it quoted so it cannot inject SQL.
                    self.quoted = self._quote(self.quoted)

        elif self.type in ("DATE","DATETIME","TIMESTAMP"):
            ## Drop a fractional-seconds suffix only when present (the
            ## former fixed [0:-3] slice truncated values without one).
            ## NOTE(review): assumes strvalue is "YYYY-MM-DD[ HH:MM:SS[.ff]]".
            self.quoted = self._quote(self._fraction.sub("", self.strvalue))

        elif self.type in ("LONGTEXT",):
            if self.strvalue:
                self.quoted = hexify( self.strvalue)
            else:
                ## hexify("") yields "0x", which MySQL rejects.
                self.quoted = '""'

        else:
            self.quoted = self._quote(self.strvalue)


    def get_defaultcontrol(self):
        """Return the UI control code for this type ("TEXT" if unknown)."""
        return self.sqlcontrols.get(self.type, "TEXT")
            
def getquoted( ctype, nativevalue):
    """
    Return `nativevalue` rendered as SQL literal text for column type `ctype`.

    `ctype` is either an existing jtype (reused; its state is updated) or a
    column type string from which a temporary jtype is built.
    """
    conv = ctype
    if not isinstance(conv, jtype):
        conv = jtype( conv)
    conv.setnativevalue( nativevalue)
    return conv.quoted

def getnativevalue( ctype, value):
    """
    Convert str(value) into the Python value for column type `ctype`.

    `ctype` is either an existing jtype (reused; its state is updated) or a
    column type string from which a temporary jtype is built.
    """
    conv = ctype
    if not isinstance(conv, jtype):
        conv = jtype( conv)
    text = str(value)
    conv.setstrvalue( text)
    return conv.nativevalue

def getblank( ctype):
    """
    Return the blank Python value for an upper-case type name: 0 for numeric
    types, an mx.DateTime for DATE/DATETIME, and "" for everything else.

    NOTE(review): TIMESTAMP is not in the date group here (it gets ""),
    unlike jtype.sqlempty, which treats it like DATETIME -- confirm intent.
    """
    numeric = ("INT","TINYINT","SMALLINT","MEDIUMINT","BIGINT","LONG","DECIMAL")
    if ctype in numeric:
        return 0
    if ctype in ("DATE","DATETIME"):
        return mx.DateTime.DateFrom("01-01-0001")
    ## CHAR, VARCHAR, LONGTEXT and any unknown type
    return ""

def hexify(cString=""):
    """
    Return str(cString) as a MySQL hex literal: "0x" followed by two
    upper-case hex digits per character (e.g. "AB" -> "0x4142").
    """
    pairs = [hex(ord(ch))[2:4].upper().zfill(2) for ch in str(cString)]
    return "0x" + "".join(pairs)