函式修飾器


你設計了一個點餐程式,目前主餐有炸雞,價格為49元:
def friedchicken():
return 49.0

print(friedchicken()) # 49.0

之後在幾個地方都 呼叫了friedchicken()函式來計算餐點價格,現在你打算增加附餐,但又不想直接修改friedchicken()函式,也不想另外增加一個 friedchickenside1()函式,然後到處修改先前使用到friedchicken()函式的地方,則你可以這麼撰寫:
def sidedish1(meal):
return lambda: meal() + 30

def friedchicken():
return 49.0

friedchicken = sidedish1(friedchicken)
print(friedchicken()) # 79.0

sidedish1()接受函式物件,函式中使用lamdba建立一個函式物件,該函式物件執行傳入的函式取得主餐價格,再加上附餐價格,sidedish1()傳回所建立的函式物件給friedchicken參考,所以之後執行的friedchicken(),就會是主餐加附餐的價格。

以上是傳遞函式的一個應用。在Python中,你還可以使用以下的語法:
def sidedish1(meal):
return lambda: meal() + 30

@sidedish1
def friedchicken():
return 49.0

print(friedchicken()) # 79.0

@之後所接上的名稱,實際上就是個函式,@sidedish1這樣的標注方式,讓@sidedish1更像是個修飾器(Decorator),將friedchicken()函式加以修飾,增加附餐價格。

你可以堆疊修飾器,例如:
def sidedish1(meal):
return lambda: meal() + 30

def sidedish2(meal):
return lambda: meal() + 40

@sidedish1
@sidedish2
def friedchicken():
return 49.0

print(friedchicken()) # 119.0

上例實際上等同於:
def sidedish1(meal):
return lambda: meal() + 30

def sidedish2(meal):
return lambda: meal() + 40

def friedchicken():
return 49.0

friedchicken = sidedish1(sidedish2(friedchicken))

print(friedchicken()) # 119.0

如果你的修飾器語法需要帶有參數,則記得,會先以參數執行一次修飾器,傳回函式物件再修飾指定的函式。例如:
@deco('param')
def func():
    pass

實際上等於:
func = deco('param')(func)

所以若要讓點餐程式更有彈性一些,你可以這麼設計:
def sidedish(number):
return {
1 : lambda meal: (lambda: meal() + 30),
2 : lambda meal: (lambda: meal() + 40),
3 : lambda meal: (lambda: meal() + 50),
4 : lambda meal: (lambda: meal() + 60)
}.get(number, lambda meal: (lambda: meal()))

@sidedish(2)
@sidedish(3)
def friedchicken():
return 49.0

print(friedchicken()) # 139.0

以上的程式都是使用lamdba建立傳回的函式,若不易理解,以下這個是個較清楚的版本:
def sidedish(number):
def dish1(meal):
return lambda: meal() + 30

def dish2(meal):
return lambda: meal() + 40

def dish3(meal):
return lambda: meal() + 50

def dish4(meal):
return lambda: meal() + 60

def nodish(meal):
return lambda: meal()

return {
1 : dish1,
2 : dish2,
3 : dish3,
4 : dish4
}.get(number, nodish)

@sidedish(2)
@sidedish(3)
def friedchicken():
return 49.0

print(friedchicken()) # 139.0