import  wx
import  wx.grid as  gridlib

class EventGrid(object):
    """Mixin holding the grid coordinates carried by the custom grid events.

    Both coordinates start out as ``None``; whoever raises the event fills
    them in through ``SetCol`` / ``SetRow`` before processing it.
    """

    def __init__(self):
        self.__col = self.__row = None

    # ---- column --------------------------------------------------------
    def GetCol(self):
        """Return the stored column index (``None`` until ``SetCol``)."""
        return self.__col

    def SetCol(self, col):
        """Remember *col* as this event's column index."""
        self.__col = col

    # ---- row -----------------------------------------------------------
    def GetRow(self):
        """Return the stored row index (``None`` until ``SetRow``)."""
        return self.__row

    def SetRow(self, row):
        """Remember *row* as this event's row index."""
        self.__row = row

GRID_BEFORE_FIELD = wx.NewEventType()
EVT_GRID_BEFORE_FIELD = wx.PyEventBinder(GRID_BEFORE_FIELD, 1)
class evt_before_field(wx.PyCommandEvent, EventGrid):
    """Event raised when the grid cursor moves into a different column.

    ``GetCol()`` is the column being entered.
    """
    def __init__(self, *args, **kwargs):
        # Arguments are forwarded unchanged so the constructor signature is
        # exactly wx.PyCommandEvent's.  EventGrid.__init__ is then called
        # explicitly (as evt_before_row does) so GetRow()/GetCol() return
        # None instead of raising AttributeError before the first Set call.
        wx.PyCommandEvent.__init__(self, *args, **kwargs)
        EventGrid.__init__(self)

GRID_AFTER_FIELD = wx.NewEventType()
EVT_GRID_AFTER_FIELD = wx.PyEventBinder(GRID_AFTER_FIELD, 1)
class evt_after_field(wx.PyCommandEvent, EventGrid):
    """Event raised when the grid cursor leaves its current column.

    ``GetCol()`` is the column being left.
    """
    def __init__(self, *args, **kwargs):
        # Arguments are forwarded unchanged so the constructor signature is
        # exactly wx.PyCommandEvent's.  EventGrid.__init__ is then called
        # explicitly (as evt_after_row does) so GetRow()/GetCol() return
        # None instead of raising AttributeError before the first Set call.
        wx.PyCommandEvent.__init__(self, *args, **kwargs)
        EventGrid.__init__(self)

GRID_BEFORE_ROW = wx.NewEventType()
EVT_GRID_BEFORE_ROW = wx.PyEventBinder(GRID_BEFORE_ROW, 1)
class evt_before_row(wx.PyCommandEvent, EventGrid):
    """Event raised when the grid cursor moves onto a different row.

    ``GetRow()`` is the row being entered.
    """
    def __init__(self, evtType, id):
        # The first base in the MRO is wx.PyCommandEvent, so super() resolves
        # to its initializer; the EventGrid mixin is initialized explicitly.
        super(evt_before_row, self).__init__(evtType, id)
        EventGrid.__init__(self)

GRID_AFTER_ROW = wx.NewEventType()
EVT_GRID_AFTER_ROW = wx.PyEventBinder(GRID_AFTER_ROW, 1)
class evt_after_row(wx.PyCommandEvent, EventGrid):
    """Event raised when the grid cursor leaves its current row.

    ``GetRow()`` is the row being left.
    """
    def __init__(self, evtType, id):
        # The first base in the MRO is wx.PyCommandEvent, so super() resolves
        # to its initializer; the EventGrid mixin is initialized explicitly.
        super(evt_after_row, self).__init__(evtType, id)
        EventGrid.__init__(self)

class TGrid(gridlib.PyGridTableBase):
    """Grid table backed by an in-memory list of row lists.

    ``headers`` supplies the column labels, ``dtypes`` the wx grid type name
    of each column (optionally followed by ``:params``, e.g.
    ``GRID_VALUE_CHOICE + ':a,b'``), and ``data`` the list of rows holding
    the cell values.  Writing past the last row grows the table.
    """

    def __init__(self, headers, dtypes, data):
        gridlib.PyGridTableBase.__init__(self)
        self.headers = headers
        self.dtypes = dtypes
        self.data = data

    def GetNumberRows(self):
        """Return the number of rows currently stored in ``data``."""
        return len(self.data)

    def GetNumberCols(self):
        """Return the column count.

        Taken from the first data row; falls back to the number of headers
        when ``data`` is empty so an empty table does not raise IndexError.
        """
        if self.data:
            return len(self.data[0])
        return len(self.headers)

    def IsEmptyCell(self, row, col):
        """Return True for falsy cell values ('' but also 0/False) or cells
        outside the stored data."""
        try:
            return not self.data[row][col]
        except IndexError:
            return True

    def GetValue(self, row, col):
        """Return the cell value, or '' for cells outside the stored data."""
        try:
            return self.data[row][col]
        except IndexError:
            return ''

    def SetValue(self, row, col, value):
        """Store *value* at (row, col), appending blank rows when *row* is
        past the end of the table.

        Appended rows are reported to the attached grid view (if any) with a
        single GRIDTABLE_NOTIFY_ROWS_APPENDED message.

        Raises:
            IndexError: if *col* is outside the row.  The table is left
                unchanged in that case.
        """
        if row < len(self.data):
            self.data[row][col] = value
            return

        width = self.GetNumberCols()
        missing = row + 1 - len(self.data)
        new_rows = [[''] * width for _ in range(missing)]
        # Assign before extending so a bad column raises without leaving
        # half-added rows behind.
        new_rows[-1][col] = value
        self.data.extend(new_rows)

        # Tell the grid how many rows were added.
        view = self.GetView()
        if view is not None:
            msg = gridlib.GridTableMessage(
                self,                                    # the table
                gridlib.GRIDTABLE_NOTIFY_ROWS_APPENDED,  # what we did to it
                missing,                                 # how many rows
            )
            view.ProcessTableMessage(msg)

    def GetColLabelValue(self, col):
        """Return the header label of column *col*."""
        return self.headers[col]

    def GetTypeName(self, row, col):
        """Return the full wx type string (with parameters) of column *col*."""
        return self.dtypes[col]

    # Called to determine how the data can be fetched and stored by the
    # editor and renderer.  This allows you to enforce some type-safety
    # in the grid.
    def CanGetValueAs(self, row, col, typeName):
        """Return True when *typeName* matches the column's base type
        (the part of its dtype before any ':params')."""
        return typeName == self.dtypes[col].split(':')[0]

    def CanSetValueAs(self, row, col, typeName):
        """Same rule as CanGetValueAs."""
        return self.CanGetValueAs(row, col, typeName)

class Grid(gridlib.Grid):
    """Editable grid that raises field/row enter and leave events.

    On each cursor move, the events are processed in this order:
    EVT_GRID_AFTER_FIELD (column being left), EVT_GRID_AFTER_ROW (row being
    left), EVT_GRID_BEFORE_ROW (row being entered) and EVT_GRID_BEFORE_FIELD
    (column being entered).  "Leave" events carry the coordinates of the
    cell being left; "enter" events carry the cell being entered.
    """

    def __init__(self, parent):
        gridlib.Grid.__init__(self, parent, -1)
        headers = ['ID', 'Description', 'Severity', 'Priority', 'Platform',
                   'Opened?', 'Fixed?', 'Tested?', 'TestFloat']

        # Start with one blank row; Enter on the last row appends more.
        data = [[''] * len(headers)]

        types = [gridlib.GRID_VALUE_NUMBER,
                 gridlib.GRID_VALUE_STRING,
                 gridlib.GRID_VALUE_CHOICE + ':only in a million years!,wish list,minor,normal,major,critical',
                 gridlib.GRID_VALUE_NUMBER + ':1,5',
                 gridlib.GRID_VALUE_CHOICE + ':all,MSW,GTK,other',
                 gridlib.GRID_VALUE_BOOL,
                 gridlib.GRID_VALUE_BOOL,
                 gridlib.GRID_VALUE_BOOL,
                 gridlib.GRID_VALUE_FLOAT + ':6,2',
                ]

        self.table = TGrid(headers, types, data)
        self.SetTable(self.table, True)
        parent.Bind(wx.EVT_ACTIVATE, self.OnActivate)
        self.Bind(wx.EVT_KEY_DOWN, self.OnKeyDown)
        self.Bind(gridlib.EVT_GRID_SELECT_CELL, self.OnGridSelectCell)

        self.EnableGridLines(True)
        self.SetGridLineColour('#DDDDDD')
        self.SetRowLabelSize(1)

    def OnActivate(self, evt):
        """Open the cell editor whenever the parent window's activation changes."""
        self.EnableCellEditControl()
        evt.Skip()

    def _FireEvent(self, event_class, event_type, row, col):
        """Create a custom grid event for (row, col) and process it synchronously."""
        e = event_class(event_type, self.GetId())
        e.SetRow(row)
        e.SetCol(col)
        self.GetEventHandler().ProcessEvent(e)

    def OnGridSelectCell(self, evt):
        """Translate a cell-selection change into field/row leave/enter events."""
        old_row = self.GetGridCursorRow()
        old_col = self.GetGridCursorCol()
        new_row = evt.GetRow()
        new_col = evt.GetCol()

        # A cursor coordinate of -1 is treated as "no previous cell", so no
        # leave event is sent for it.
        if old_col != new_col and old_col != -1:
            self._FireEvent(evt_after_field, GRID_AFTER_FIELD, old_row, old_col)

        if old_row != new_row:
            if old_row != -1:
                self._FireEvent(evt_after_row, GRID_AFTER_ROW, old_row, old_col)
            self._FireEvent(evt_before_row, GRID_BEFORE_ROW, new_row, new_col)

        if old_col != new_col:
            self._FireEvent(evt_before_field, GRID_BEFORE_FIELD, new_row, new_col)

        self.SelectRow(new_row)
        wx.CallAfter(self.EnableCellEditControl)
        evt.Skip()

    def OnKeyDown(self, evt):
        """Make Enter move to the next cell, wrapping and appending rows."""
        wx.CallAfter(self.EnableCellEditControl)

        # Only a plain Enter is handled here.  Every other key, and
        # Ctrl+Enter (which the edit control needs), goes to default handling.
        # GetKeyCode() works on all wxPython versions; KeyCode is a property
        # on 2.8+/Phoenix, so calling it fails.
        if evt.GetKeyCode() != wx.WXK_RETURN or evt.ControlDown():
            evt.Skip()
            return

        if self.MoveCursorRight(evt.ShiftDown()):
            return

        # Already in the last column: wrap to column 0 of the next row,
        # appending a blank row first when the cursor is on the last one.
        next_row = self.GetGridCursorRow() + 1
        if next_row >= self.GetTable().GetNumberRows():
            self.table.SetValue(next_row, 0, '')
        self.SetGridCursor(next_row, 0)
        self.MakeCellVisible(next_row, 0)

class TestFrame(wx.Frame):
    """Demo frame hosting a Grid and printing its custom field/row events.

    Each message is passed to print() as one pre-formatted string. That way
    the module parses, and prints identically, under both Python 2 (print
    statement) and Python 3 (print function).
    """

    def __init__(self, parent):
        wx.Frame.__init__(self, parent, -1)
        test_grid = Grid(self)

        test_grid.Bind(EVT_GRID_BEFORE_FIELD, self.OnGridBeforeField)
        test_grid.Bind(EVT_GRID_AFTER_FIELD, self.OnGridAfterField)
        test_grid.Bind(EVT_GRID_BEFORE_ROW, self.OnGridBeforeRow)
        test_grid.Bind(EVT_GRID_AFTER_ROW, self.OnGridAfterRow)

    def OnGridBeforeField(self, evt):
        """Log the column the cursor is entering."""
        print('Before Field %s' % evt.GetCol())
        evt.Skip()

    def OnGridAfterField(self, evt):
        """Log the column the cursor is leaving."""
        print('After Field %s' % evt.GetCol())
        evt.Skip()

    def OnGridBeforeRow(self, evt):
        """Log the row the cursor is entering, followed by a blank line."""
        print('Before Row %s' % evt.GetRow())
        print('')
        evt.Skip()

    def OnGridAfterRow(self, evt):
        """Log the row the cursor is leaving, followed by a blank line."""
        print('After Row %s' % evt.GetRow())
        print('')
        evt.Skip()

if __name__ == '__main__':
    # Stand-alone demo: show the test frame and run the wx event loop.
    demo_app = wx.App()
    demo_frame = TestFrame(None)
    demo_frame.Show()
    demo_app.MainLoop()