[目錄]
0.前言
1.什么是 `__name__`?
2.`if __name__ == '__main__'`: 的作用
3.為何Windows更需`if __name__ =`?
前言
if __name__ == '__main__':
是 Python 中一個非常重要的慣用法,尤其在使用 multiprocessing
模塊或編寫可導入的模塊時。它的作用是區分腳本是直接運行還是被導入,從而控制代碼的執行行為。很多初學者可能對此感到困惑,不明白其真正的用途和重要性。
直到發現自己的CPU因為不良的代碼習慣而干燒了數十分鐘才知道后悔
下面詳細解釋它的作用和工作原理。
1. 什么是 __name__
?
__name__
是 Python 中的一個內置變量,它的值取決于腳本的運行方式:
-
當腳本被直接運行時(例如通過
python script.py
運行):__name__
的值是'__main__'
。- 這是 Python 解釋器自動設置的,表示當前腳本是“主腳本”。
-
當腳本被導入為模塊時(例如
import script
):__name__
的值是模塊的名稱(例如script
)。- 此時,腳本中的代碼會被執行,但
__name__
不再是'__main__'
。
示例
假設有一個腳本 example.py
:
print("The value of __name__ is:", __name__)if __name__ == '__main__':print("This script is being run directly.")
else:print("This script is being imported as a module.")
-
直接運行:
python example.py
輸出:
The value of __name__ is: __main__ This script is being run directly.
-
作為模塊導入: 創建另一個腳本
importer.py
:import example
沒錯,這個新腳本就這么短小精悍。
輸出:
The value of __name__ is: example This script is being imported as a module.
2.if __name__ == '__main__'
: 的作用
if __name__ == '__main__'
: 的作用是讓某些代碼塊只在腳本被直接運行時執行,而在腳本被導入時不執行。這有以下幾個重要用途:
避免導入時的副作用
當一個 Python 腳本被導入為模塊時,腳本中的所有頂層代碼(不在函數或類中的代碼)都會被執行。
如果這些頂層代碼包含不希望在導入時運行的操作(例如啟動服務器、執行復雜計算、創建進程等),會產生意外行為。
使用 if name == ‘main’:,可以確保這些代碼只在腳本被直接運行時執行。
示例
假設有一個腳本 math_utils.py:
# 頂層代碼
print("This will always run when the script is imported!")def add(a, b):return a + b# 不使用 if __name__ == '__main__':
result = add(2, 3)
print(f"Result of add(2, 3): {result}")
另一個腳本 main.py
導入了math_utils:
import math_utilsprint("Using math_utils to add numbers...")
print(math_utils.add(5, 6))
運行main.py:
This will always run when the script is imported!
Result of add(2, 3): 5
Using math_utils to add numbers...
11
問題在于,math_utils.py
中的頂層代碼(print
和 result = add(2, 3)
)在導入時被執行了,這可能不是我們想要的。
現在使用 if __name__ == '__main__':
修改 math_utils.py:
print("This will always run when the script is imported!")def add(a, b):return a + bif __name__ == '__main__':result = add(2, 3)print(f"Result of add(2, 3): {result}")
這時我們得到的運行后結果為:
This will always run when the script is imported!
Using math_utils to add numbers...
11
我們可以發現,if __name__ == '__main__':
塊中的代碼(result = add(2, 3)
和相關的 print
)只在 math_utils.py
被直接運行時執行,導入時不會運行。
其他用途
(1) 測試代碼
你可以在 if name == ‘main’: 中添加測試代碼,這些代碼只在腳本直接運行時執行,而不會在導入時運行。譬如:
def add(a, b):return a + bif __name__ == '__main__':# 測試代碼print(add(2, 3))print(add(5, 6))
直接運行時,測試代碼會執行;而導入時,測試代碼則不會運行。
(2) 命令行工具
許多命令行工具使用 if __name__ == '__main__':
來定義入口點,確保主邏輯只在直接運行時執行。譬如:
import sysdef main():print("Hello, world!")print("Arguments:", sys.argv)if __name__ == '__main__':main()
如果不使用 if __name__ == '__main__':
保護,主邏輯會在腳本被導入時意外執行,這可能導致不希望的行為,尤其是在命令行工具中。
3. 為何Windows更需if __name__ =
?
這主要取決于我們 coding \texttt{coding} coding的場景——是否會用到multiprocessing
或是其它類似方法來加速我們的計算。
在 Linux 上,multiprocessing
默認使用 fork
方法創建子進程。fork
會直接復制主進程的內存狀態,子進程不會重新加載腳本,因此頂層代碼不會被重復執行。
在 Windows 上,multiprocessing
使用 spawn
方法,必須重新加載腳本,導致頂層代碼被重復執行,因此需要 if __name__ == '__main__':
保護。
值得提醒的是,當我們在 Github \texttt{Github} Github上拿到其它大佬的項目代碼時,我們自己在本地運行時一定要檢查是否存在類似問題。以pix2pix與CycleGAN項目為例,其原始代碼為:
import os
import numpy as np
import cv2
import argparse
from multiprocessing import Pooldef image_write(path_A, path_B, path_AB):im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLORim_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLORim_AB = np.concatenate([im_A, im_B], 1)cv2.imwrite(path_AB, im_AB)parser = argparse.ArgumentParser('create image pairs')
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
parser.add_argument('--no_multiprocessing', dest='no_multiprocessing', help='If used, chooses single CPU execution instead of parallel execution', action='store_true',default=False)
args = parser.parse_args()for arg in vars(args):print('[%s] = ' % arg, getattr(args, arg))splits = os.listdir(args.fold_A)if not args.no_multiprocessing:pool=Pool()for sp in splits:img_fold_A = os.path.join(args.fold_A, sp)img_fold_B = os.path.join(args.fold_B, sp)img_list = os.listdir(img_fold_A)if args.use_AB:img_list = [img_path for img_path in img_list if '_A.' in img_path]num_imgs = min(args.num_imgs, len(img_list))print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))img_fold_AB = os.path.join(args.fold_AB, sp)if not os.path.isdir(img_fold_AB):os.makedirs(img_fold_AB)print('split = %s, number of images = %d' % (sp, num_imgs))for n in range(num_imgs):name_A = img_list[n]path_A = os.path.join(img_fold_A, name_A)if args.use_AB:name_B = name_A.replace('_A.', '_B.')else:name_B = name_Apath_B = os.path.join(img_fold_B, name_B)if os.path.isfile(path_A) and os.path.isfile(path_B):name_AB = name_Aif args.use_AB:name_AB = name_AB.replace('_A.', '.') # remove _Apath_AB = os.path.join(img_fold_AB, name_AB)if not args.no_multiprocessing:pool.apply_async(image_write, args=(path_A, path_B, path_AB))else:im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLORim_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLORim_AB = np.concatenate([im_A, im_B], 1)cv2.imwrite(path_AB, im_AB)
if not args.no_multiprocessing:pool.close()pool.join()
即是顯然是在 Linux \texttt{Linux} Linux系統上使用的multiprocessing
方法。本人一開始尚未注意到該問題,結果CPU干燒了十幾分鐘,出現類似的 RuntimeError \texttt{RuntimeError} RuntimeError
將代碼修正后:
import os
import numpy as np
import cv2
import argparse
from multiprocessing import Pooldef image_write(path_A, path_B, path_AB):im_A = cv2.imread(path_A, 1)im_B = cv2.imread(path_B, 1)if im_A is None or im_B is None:print(f"Failed to load images: {path_A} or {path_B}")returnim_AB = np.concatenate([im_A, im_B], 1)cv2.imwrite(path_AB, im_AB)if __name__ == '__main__':parser = argparse.ArgumentParser('create image pairs')parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')parser.add_argument('--no_multiprocessing', dest='no_multiprocessing', help='If used, chooses single CPU execution instead of parallel execution', action='store_true', default=False)args = parser.parse_args()for arg in vars(args):print('[%s] = ' % arg, getattr(args, arg))splits = os.listdir(args.fold_A)if not args.no_multiprocessing:pool = Pool()for sp in splits:img_fold_A = os.path.join(args.fold_A, sp)img_fold_B = os.path.join(args.fold_B, sp)img_list = os.listdir(img_fold_A)if args.use_AB:img_list = [img_path for img_path in img_list if '_A.' in img_path]num_imgs = min(args.num_imgs, len(img_list))print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))img_fold_AB = os.path.join(args.fold_AB, sp)if not os.path.isdir(img_fold_AB):os.makedirs(img_fold_AB)print('split = %s, number of images = %d' % (sp, num_imgs))for n in range(num_imgs):name_A = img_list[n]path_A = os.path.join(img_fold_A, name_A)if args.use_AB:name_B = name_A.replace('_A.', '_B.')else:name_B = name_Apath_B = os.path.join(img_fold_B, name_B)if os.path.isfile(path_A) and os.path.isfile(path_B):print(f"Found pair: {path_A} and {path_B}")name_AB = name_Aif args.use_AB:name_AB = name_AB.replace('_A.', '.') # remove _Apath_AB = os.path.join(img_fold_AB, name_AB)if not args.no_multiprocessing:pool.apply_async(image_write, args=(path_A, path_B, path_AB))else:image_write(path_A, path_B, path_AB)else:print(f"Pair not found: {path_A} or {path_B}")if not args.no_multiprocessing:pool.close()pool.join()
終于能夠正常運行。