import wx

class ImageViewer(wx.Frame):
    """Top-level frame that displays one bitmap with wheel zoom and drag pan.

    Interaction:
      * mouse wheel       -- zoom in/out, keeping the point under the cursor fixed;
      * left button drag  -- pan the view;
      * left double-click -- toggle "scale to fit"; while enabled the image is
                             re-fitted and centred whenever the frame is resized.

    Painting applies ``Translate(-pan)`` then ``Scale(zoom)`` to a
    wx.GraphicsContext, so a drawing-space point ``p`` lands on screen at
    ``zoom * p - pan``; the pan vector is therefore in screen pixels.
    """

    def __init__(self, parent=None, image_path="sample_image.jpeg"):
        """Create the viewer frame and load the image it displays.

        Args:
            parent: Parent window, or None for a top-level frame.
            image_path: Image file to load (format auto-detected via
                wx.BITMAP_TYPE_ANY). Defaults to the previously hard-coded name.

        Raises:
            ValueError: If the image could not be loaded.
        """
        # Load first so a bad path fails before the native frame is created,
        # instead of surfacing later as -1 sizes / "invalid bitmap" asserts.
        raw_bitmap = wx.Bitmap(image_path, wx.BITMAP_TYPE_ANY)
        if not raw_bitmap.IsOk():
            raise ValueError("could not load image: %r" % (image_path,))

        wx.Frame.__init__(self, parent)

        self._scaleToFit = True
        self._FirstPaint = True

        self._RawBitmap = raw_bitmap
        # Renderer-specific copy of the bitmap, created lazily in OnPaint.
        self._DrawBitmap = None
        self._BitmapWidth = raw_bitmap.GetWidth()
        self._BitmapHeight = raw_bitmap.GetHeight()

        self._ZoomFactor = 1.0
        # Image origin in unscaled drawing coordinates (set by ScaleToFit).
        self._Start_X_Pos = self._Start_Y_Pos = 0

        # Committed pan plus the delta of any drag currently in progress.
        self._panVector = wx.Point2D(0, 0)
        self._inProgressPanStartPoint = wx.Point(0, 0)
        self._inProgressPanVector = wx.Point2D(0, 0)
        self._panInProgress = False

        # OnPaint clears and draws everything itself.
        self.SetBackgroundStyle(wx.BG_STYLE_PAINT)

        # Bind last, once every attribute the handlers read exists.
        self.Bind(wx.EVT_PAINT, self.OnPaint)
        self.Bind(wx.EVT_MOUSEWHEEL, self.OnMouseWheel)
        self.Bind(wx.EVT_LEFT_DOWN, self.OnLeftDown)
        self.Bind(wx.EVT_LEFT_DCLICK, self.OnLeftDClick)
        self.Bind(wx.EVT_SIZE, self.OnSize)

    def ScaleToFit(self):
        """Zoom so the whole image fits the client area, centre it, and drop any pan.

        Does nothing while the client area (or the image) has no extent,
        e.g. while the frame is minimised; this used to raise
        ZeroDivisionError from OnSize.
        """
        size = self.GetClientSize()
        if (size.x <= 0 or size.y <= 0
                or self._BitmapWidth <= 0 or self._BitmapHeight <= 0):
            return

        win_aspect = size.x / size.y
        img_aspect = self._BitmapWidth / self._BitmapHeight

        # If the window is wider than the image, the height determines the scale factor.
        if win_aspect > img_aspect:
            zoom = size.y / self._BitmapHeight
        else:
            zoom = size.x / self._BitmapWidth
        self._ZoomFactor = zoom

        # Centre the image. DrawBitmap positions are in unscaled coordinates
        # that the context multiplies by zoom, hence the division.
        self._Start_X_Pos = (size.x - self._BitmapWidth * zoom) / 2 / zoom
        self._Start_Y_Pos = (size.y - self._BitmapHeight * zoom) / 2 / zoom

        # The centring above assumes zero pan; discard pan left over from
        # earlier drags / wheel zooms so a re-fit really is centred.
        self._panVector = wx.Point2D(0, 0)

        self.Refresh()

    def DoDrawCanvas(self, gc):
        """Draw the cached bitmap into the already-transformed context *gc*.

        Also writes the current origin and zoom into the frame title.
        """
        gc.DrawBitmap(self._DrawBitmap, self._Start_X_Pos, self._Start_Y_Pos,
                      self._BitmapWidth, self._BitmapHeight)
        self.SetTitle("x: %.2f, y: %.2f, zoom: %0.2f" % (self._Start_X_Pos, self._Start_Y_Pos, self._ZoomFactor))

    def OnSize(self, event):
        """Re-fit the image on resize when scale-to-fit mode is active."""
        if self._scaleToFit:
            self.ScaleToFit()

        # Let default size handling run as well.
        event.Skip()

    def OnPaint(self, event):
        """Clear the client area, apply pan and zoom, and draw the image."""
        dc = wx.AutoBufferedPaintDC(self)
        dc.Clear()

        # NOTE: on MSW a Direct2D context is an alternative:
        #   gc = wx.GraphicsRenderer.GetDirect2DRenderer().CreateContext(dc)
        gc = wx.GraphicsContext.Create(dc)
        if not gc:
            return

        total_pan_x = self._panVector.x + self._inProgressPanVector.x
        total_pan_y = self._panVector.y + self._inProgressPanVector.y

        # Translate then Scale: drawing coordinates are scaled by zoom and
        # then shifted by the screen-pixel pan (screen = zoom * p - pan).
        gc.Translate(-total_pan_x, -total_pan_y)
        gc.Scale(self._ZoomFactor, self._ZoomFactor)

        if self._FirstPaint:
            # Convert once to a graphics bitmap and reuse it on later paints.
            self._DrawBitmap = gc.CreateBitmap(self._RawBitmap)
            self._FirstPaint = False

        self.DoDrawCanvas(gc)
        # Release the context before the paint DC goes out of scope.
        del gc

    def OnMouseWheel(self, event):
        """Zoom by 0.25 per wheel notch, keeping the point under the cursor fixed."""
        if self._panInProgress:
            self.FinishPan(False)

        old_zoom = self._ZoomFactor
        notches = event.GetWheelRotation() / event.GetWheelDelta()

        # Keep zoom positive with a 0.1 floor; if ScaleToFit already went
        # below it, use the current zoom as the floor so scrolling down never
        # zooms in. No upper bound is enforced.
        floor = min(0.1, old_zoom)
        new_zoom = max(floor, old_zoom + 0.25 * notches)
        self._ZoomFactor = new_zoom

        # The drawing-space point under cursor u is (u + pan) / zoom. Choose the
        # new pan so the same point stays under u: pan' = (u + pan) * new / old - u.
        cursor = event.GetPosition()
        self._panVector = wx.Point2D(
            (cursor.x + self._panVector.x) * new_zoom / old_zoom - cursor.x,
            (cursor.y + self._panVector.y) * new_zoom / old_zoom - cursor.y)

        self.Refresh()

    def ProcessPan(self, pt, refresh):
        """Update the in-progress pan from the drag start point to *pt*.

        Args:
            pt: Current mouse position in client coordinates.
            refresh: Whether to schedule a repaint.
        """
        # Dragging right should move the image right, i.e. decrease the pan.
        self._inProgressPanVector = self._inProgressPanStartPoint - pt

        if refresh:
            self.Refresh()

    def FinishPan(self, refresh):
        """End the current drag, committing its delta into the pan vector.

        Restores the cursor, releases the mouse capture and unbinds the
        drag-only handlers. Does nothing if no pan is in progress.

        Args:
            refresh: Whether to schedule a repaint.
        """
        if not self._panInProgress:
            return

        self.SetCursor(wx.NullCursor)

        if self.HasCapture():
            self.ReleaseMouse()

        self.Unbind(wx.EVT_LEFT_UP)
        self.Unbind(wx.EVT_MOTION)
        self.Unbind(wx.EVT_MOUSE_CAPTURE_LOST)

        self._panVector = wx.Point2D(self._panVector.x + self._inProgressPanVector.x,
                                     self._panVector.y + self._inProgressPanVector.y)
        self._inProgressPanVector = wx.Point2D(0, 0)
        self._panInProgress = False

        if refresh:
            self.Refresh()

    def GetUntransformedRect(self):
        """Return the visible client area expressed in unscaled drawing coordinates.

        Returns:
            wx.Rect2D (wx.Rect2DDouble on Classic wxPython). Its origin is the
            drawing-space point at the client area's top-left corner, and its
            size is the client size divided by the zoom. The pan of any drag
            in progress is included, matching what OnPaint draws. The image
            itself sits at (_Start_X_Pos, _Start_Y_Pos) in this space.
        """
        zoom = self._ZoomFactor  # a factor (1.0 == 100%), not a percentage
        pan_x = self._panVector.x + self._inProgressPanVector.x
        pan_y = self._panVector.y + self._inProgressPanVector.y
        # The transform applies to client coordinates; GetSize() would include
        # the frame's borders and title bar.
        size = self.GetClientSize()

        # Phoenix names the class Rect2D; Classic called it Rect2DDouble.
        rect_type = getattr(wx, "Rect2D", None) or wx.Rect2DDouble
        return rect_type(pan_x / zoom, pan_y / zoom,
                         size.GetWidth() / zoom, size.GetHeight() / zoom)

    def OnLeftDClick(self, event):
        """Toggle scale-to-fit mode, fitting immediately when it is switched on."""
        self._scaleToFit = not self._scaleToFit
        if self._scaleToFit:
            self.ScaleToFit()

    def OnLeftDown(self, event):
        """Begin a drag-pan: show a hand cursor, capture the mouse, bind drag handlers."""
        # Never nest captures or double-bind the handlers if a pan is somehow
        # still active; commit it first.
        if self._panInProgress:
            self.FinishPan(False)

        self.SetCursor(wx.Cursor(wx.CURSOR_HAND))

        self._inProgressPanStartPoint = event.GetPosition()
        self._inProgressPanVector = wx.Point2D(0, 0)
        self._panInProgress = True

        self.Bind(wx.EVT_LEFT_UP, self.OnLeftUp)
        self.Bind(wx.EVT_MOTION, self.OnMotion)
        self.Bind(wx.EVT_MOUSE_CAPTURE_LOST, self.OnCaptureLost)

        self.CaptureMouse()

    def OnMotion(self, event):
        """Track the drag and repaint."""
        self.ProcessPan(event.GetPosition(), True)

    def OnLeftUp(self, event):
        """Apply the final drag position and commit the pan."""
        self.ProcessPan(event.GetPosition(), False)
        self.FinishPan(True)

    def OnCaptureLost(self, event):
        """Commit the pan if the mouse capture is taken away mid-drag."""
        self.FinishPan(True)

if __name__ == '__main__':
    # Script entry point: open a single viewer frame at a fixed initial size
    # and run the event loop until it is closed.
    application = wx.App(False)
    viewer = ImageViewer(parent=None)
    viewer.SetSize(wx.Size(1200, 800))
    viewer.Show()
    application.MainLoop()
    
    
    
