A problem with canvas, sizer and scrolledwindows

Is anyone able to help me with this problem?
I attach a source as example.
I would like the scrolled window and the sizer to adapt themselves to the
increasing/decreasing size of the canvas.
But as you can see if I start with a big canvas and draw a small one I don't
get the DC and the sizer refreshed; OTH if I start with a small canvas and
then draw a big figure the sizer will contain the big canvas and I don't get
scrollbars.

···

--------------------------------------------------------------------
import wx
import matplotlib
import wxmpl

class PanelWScrolledWindow(wx.ScrolledWindow):
  def __init__(self,parent):
    wx.ScrolledWindow.__init__(self, parent)
    self.SetScrollRate(20,20)

    self.mainsizer = wx.BoxSizer(wx.VERTICAL)
    self.BTNsizer = wx.BoxSizer(wx.HORIZONTAL)

    self.canvas = wxmpl.PlotPanel(self, -1, size=(20,10)) #START
WITH A BIG CANVAS
    #self.canvas = wxmpl.PlotPanel(self, -1, size=(2,1)) #START
WITH A SMALL CANVAS

    BTNsmall = wx.Button(self, -1, "Small FIG")
    BTNbig = wx.Button(self, -1, "Big FIG")

    self.Bind(wx.EVT_BUTTON, self.onBTNbig, BTNbig)
    self.Bind(wx.EVT_BUTTON, self.onBTNsmall, BTNsmall)

    self.BTNsizer.Add (BTNsmall, 1, wx.GROW|wx.ALL, 1)
    self.BTNsizer.Add (BTNbig, 1, wx.GROW|wx.ALL, 1)
  
    self.mainsizer.Add (self.canvas, 1, wx.GROW|wx.ALL, 1)
    self.mainsizer.Add (self.BTNsizer, 0, wx.GROW|wx.ALL, 1)
    
    self.SetSizer(self.mainsizer)

  def onBTNbig(self,event):
    plot_simple(self.canvas.get_figure(), (20,10))
    self.canvas.draw()

  def onBTNsmall(self,event):
    plot_simple(self.canvas.get_figure(), (2,2))
    self.canvas.draw()

class MyFrame(wx.Frame):
  def __init__(self, parent, title):
    wx.Frame.__init__(self, parent)

    self.nb = wx.Notebook(self)

    self.PageOne = PanelWScrolledWindow(self.nb)
    self.PageTwo = wx.Panel(self.nb)

    self.nb.AddPage(self.PageOne, "PageOne")
    self.nb.AddPage(self.PageTwo, "PageTwo")

def plot_simple(fig, size):

  fig.set_size_inches(size)

  t = arange(0.0, 2.0, 0.01)
  s = sin(2*pi*t)
  c = cos(2*pi*t)
  
  axes = fig.gca()
  axes.plot(t, s, linewidth=1.0)
  axes.plot(t, c, linewidth=1.0)
  
  axes.set_xlabel('time (s)')
  axes.set_ylabel('voltage (mV)')
  axes.set_title('About as simple as it gets, folks')
  axes.grid(True)

#here we start
a = 10
app = wx.PySimpleApp()
frm = MyFrame(None, "Test")
frm.SetSize((800,600))
frm.Show()
app.SetTopWindow(frm)
app.MainLoop()