5.?方言與操作
5.1.?方言的概念
在MLIR里,通過Dialect類來抽象方言。具體的每種方言都需要從這個基類派生一個類型,并實現重載自己所需的虛函數。?
MLIR文檔里這樣描述方言( MLIR Language Reference - MLIR):
方言是這樣的機制:它融入并擴展MLIR生態系統。它們允許定義新的操作,以及屬性與類型。向每個方言給出唯一的名字空間作為定義的每個屬性/操作/類型的前綴。例如,Affine方言定義了名字空間affine。
MLIR允許多個方言共存于一個模塊中,即使是在主干之外的那些。方言由特定的遍生成與消費。對不同的方言之間以及方言內部的轉換,MLIR提供了一個框架。MLIR支持的幾個方言:
- Affine dialect
- GPU dialect
- LLVM dialect
- SPIR-V dialect
- Standard dialect
- Vector dialect
?在教程中,還給出了Toy方言的例子。
5.2.?操作的概念
MLIR引入了一個稱為操作(operation)的統一概念來描述許多不同的抽象與計算層次。在MLIR系統中,從指令到函數再到模塊,一切都塑造為Op。 MLIR沒有固定的Op集合,因此允許并鼓勵用戶自定義擴展Op。編譯器遍會保守地對待未知Op,并且MLIR支持通過特征(traits)、特權操作hook和優化接口等方式向遍描述Op語義。MLIR里的操作是完全可擴展的(沒有固定的操作列表),并具有應用特定的語義。例如,MLIR支持目標無關操作、仿射(affine)操作,以及目標特定機器操作。
操作的內部表示是簡單的:操作由一個唯一字符串標識(如dim、tf.Conv2d、x86.repmovsb、ppc.eieio等),可以返回0或多個結果,接受0或多個操作數,有一個屬性字典,有0或多個后繼者,以及0或多個封閉的區域。通用打印形式包括所有這些元素,加上一個函數類型來表示結果與操作數的類型。
例子:
// An operation that produces two results.
// The results of %result can be accessed via the <name> `#` <opNo> syntax.
%result:2 = "foo_div"() : () -> (f32, i32)
// Pretty form that defines a unique name for each result.
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
// and an attribute "fruit".
%2 = "tf.scramble"(%result#0, %bar) {fruit = "banana"} : (f32, i32) -> f32
5.3.?方言的管理
顯然,管理方言最恰當的地方就是MLIRContext,不過MLIRContext只是上下文的接口,真正的實現是MLIRContextImpl。在MLIRContextImpl中有這樣一些容器:
DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects
DialectRegistry dialectsRegistry
llvm::StringMap<AbstractOperation> registeredOperations
llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>, llvm::BumpPtrAllocator &> identifiers
第一個容器保存已載入的方言對象,第二個容器記錄已注冊的方言,第三個容器保存已注冊的抽象操作,第四個容器記錄上下文里可見的標識符。
這里,方言的注冊是指MLIRContext知道這個方言的標識符和構造方法,但方言對象并沒有構造。方言對象的構造發生在載入時,在載入時刻,不僅構造方言對象,相關的接口也會一并準備好。抽象操作與操作相關,參考操作的管理。
5.3.1.?方言的注冊
MLIR提供了一組標準的方言,它們提供了許多有用的功能。為了讓程序能方便地使用標準方言,首先,每個程序在main()的入口都要注冊標準方言,像這樣:
int main(int argc, char **argv) {
? mlir::registerAllDialects();
? …????? // 其他初始化
?registerAllDialects()的定義如下:
67? inline void registerAllDialects() {
68? ??static bool initOnce =
69? ??????([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
70? ??(void)initOnce;
71? }
69行的getGlobalDialectRegistry()返回一個類型為llvm::ManagedStatic<DialectRegistry>的對象dialectRegistry,它過可視為DialectRegistry的靜態對象,這個類通過一個類型為std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>的容器registry來記錄標準方言。因此,同一行上的registerAllDialects()重載函數是:
41? inline void registerAllDialects(DialectRegistry ®istry) {
42? ??// clang-format off
43? ??registry.insert<acc::OpenACCDialect,
44? ??????????????????AffineDialect,
45? ??????????????????avx512::AVX512Dialect,
46? ??????????????????gpu::GPUDialect,
47? ??????????????????LLVM::LLVMAVX512Dialect,
48? ??????????????????LLVM::LLVMDialect,
49? ??????????????????linalg::LinalgDialect,
50? ??????????????????scf::SCFDialect,
51? ??????????????????omp::OpenMPDialect,
52? ??????????????????pdl::PDLDialect,
53? ??????????????????pdl_interp::PDLInterpDialect,
54? ??????????????????quant::QuantizationDialect,
55? ??????????????????spirv::SPIRVDialect,
56? ??????????????????StandardOpsDialect,
57? ??????????????????vector::VectorDialect,
58? ??????????????????NVVM::NVVMDialect,
59? ??????????????????ROCDL::ROCDLDialect,
60? ??????????????????SDBMDialect,
61? ??????????????????shape::ShapeDialect>();
62? ??// clang-format on
63? }
這個方法列出了MLIR目前實現的標準方言,DialectRegistry通過一系列insert()方法完成注冊:
252? ??template <typename ConcreteDialect, typename OtherDialect,
253? ??????????????typename... MoreDialects>
254? ????void insert() {
255? ??????insert<ConcreteDialect>();
256? ??????insert<OtherDialect, MoreDialects...>();
257? ????}
241? ??template <typename ConcreteDialect>
242? ??void insert() {
243? ????insert(TypeID::get<ConcreteDialect>(),
244? ???????????ConcreteDialect::getDialectNamespace(),
245? ???????????static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
246? ?????????????// Just allocate the dialect, the context
247? ?????????????// takes ownership of it.
248? ?????????????return ctx->getOrLoadDialect<ConcreteDialect>();
249? ???????????})));
250? ??}
53? void DialectRegistry::insert(TypeID typeID, StringRef name,
54? ?????????????????????????????DialectAllocatorFunction ctor) {
55? ??auto inserted = registry.insert(
56? ??????std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
57? ??if (!inserted.second && inserted.first->second.first != typeID) {
58? ????llvm::report_fatal_error(
59? ????????"Trying to register different dialects for the same namespace: " +
60? ????????name);
61? ??}
62? }
?注意,這里只是注冊了這些方言的構造方法,并沒有把這些方言的Dialect對象構造出來。這是因為Dialect對象的構造需要一個MLIRContext實例,因此要把Dialect對象的構造推遲到MLIRContext對象構造出來后。另外,不是每個程序都需要所有的標準方言,在需要時構造所需的方言才比較合理。所以,DialectRegistry提供了兩個函數:loadByName(),loadAll()。前者構造指定名字的標準方言,后者構造所有的標準方言。
5.3.2.?方言的載入
從上面注冊的構造方法我們看到,實際執行構造的函數是MLIRContext的getOrLoadDialect(),這也是一系列調用:
69? ??template <typename T>
70? ????T *getOrLoadDialect() {
71? ??????return static_cast<T *>(
72? ????????getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
73? ??????????std::unique_ptr<T> dialect(new T(this));
74? ??????????return dialect;
75? ????????}));
76? ??}
模板參數T是具體的方言類型,它是Dialect的派生類,Dialect沒有定義getDialectNamespace(),派生類必須提供自己的定義。在MLIRContext里這個名字將作為這個方言的身份識別。另外,Dialect及其派生類亦是MLIR類型系統中的組成,它們都有TypeIDMLIRContext::getOrLoadDialect()在2021版本里的定義如下:
511 ?Dialect *
512 ?MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
513 ???????????????????????????????function_ref<std::unique_ptr<Dialect>()> ctor) {
514 ???auto &impl = getImpl();
515 ???// Get the correct insertion position sorted by namespace.
516 ???std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
517 ?
518 ???if (!dialect) {
519 ?????LLVM_DEBUG(llvm::dbgs()
520 ????????????????<< "Load new dialect in Context " << dialectNamespace << "\n");
521 ?#ifndef NDEBUG
522 ?????if (impl.multiThreadedExecutionContext != 0)
523 ???????llvm::report_fatal_error(
524 ???????????"Loading a dialect (" + dialectNamespace +
525 ???????????") while in a multi-threaded execution context (maybe "
526 ???????????"the PassManager): this can indicate a "
527 ???????????"missing `dependentDialects` in a pass for example.");
528 ?#endif
529 ?????dialect = ctor();
530 ?????assert(dialect && "dialect ctor failed");
531 ?
532 ?????// Refresh all the identifiers dialect field, this catches cases where a
533 ?????// dialect may be loaded after identifier prefixed with this dialect name
534 ?????// were already created.
535 ?????llvm::SmallString<32> dialectPrefix(dialectNamespace);
536 ?????dialectPrefix.push_back('.');
537 ?????for (auto &identifierEntry : impl.identifiers)
538 ???????if (identifierEntry.second.is<MLIRContext *>() &&
539 ???????????identifierEntry.first().startswith(dialectPrefix))
540 ?????????identifierEntry.second = dialect.get();
541 ?
542 ?????// Actually register the interfaces with delayed registration.
543 ?????impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
544 ?????return dialect.get();
545 ???}
546 ?
547 ???// Abort if dialect with namespace has already been registered.
548 ???if (dialect->getTypeID() != dialectID)
549 ?????llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
550 ??????????????????????????????"' has already been registered");
551 ?
552 ???return dialect.get();
553 ?}
從上面可以看到,構造出來的Dialect對象存放在MLIRContextImpl的loadedDialects容器中(類型DenseMap<StringRef, std::unique_ptr<Dialect>>)。同樣,MLIRContext也提供這些函數獲取指定方言的Dialect對象:getOrLoadDialect(),loadDialect()等。比如,Toy例子代碼有這樣的代碼片段來構建自己的方言對象:
int dumpMLIR() {
? mlir::MLIRContext context(/*loadAllDialects=*/false);
? // Load our Dialect in this MLIR Context.
? context.getOrLoadDialect<mlir::toy::ToyDialect>();
identifiers 是MLIRContextImpl里有類型為llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>, llvm::BumpPtrAllocator &>的容器。MLIR里的標識符是帶有上下文前綴或方言前綴的(以“.”分隔),identifiers容器就是關聯標識符與其所在上下文對象或方言對象的,一個操作在創建時首先假設它在一個上下文(參考Identifier::get()),上面537行的for循環檢查是否已經創建了具有這個方言名的上下文對象,如果是把它替換為對應的方言對象。
5.3.2.1.?方言接口
接下來,通過DialectRegistry::registerDelayedInterfaces()向MLIRContextImpl注冊方言的接口。這里“延遲接口”的意思是只在方言載入(或創建)時才注冊接口。
106? void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
107??? auto it = interfaces.find(dialect->getTypeID());
108??? if (it == interfaces.end())
109????? return;
110?
111??? // Add an interface if it is not already present.
112??? for (const auto &kvp : it->getSecond().dialectInterfaces) {
113????? if (dialect->getRegisteredInterface(kvp.first))
114??????? continue;
115????? dialect->addInterface(kvp.second(dialect));
116??? }
117?
118?? ?// Add attribute, operation and type interfaces.
119??? for (const auto &kvp : it->getSecond().objectInterfaces)
120????? kvp.second(dialect->getContext());
121? }
方言可用的接口都保存在DialectRegistry類型為DenseMap<TypeID, DelayedInterfaces>的interfaces容器中,其中DelayedInterfaces是DialectRegistry里這樣的一個嵌套定義:
283 ?struct DelayedInterfaces {
284 ?????/// Dialect interfaces.
285 ?????SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
286 ?????????dialectInterfaces;
287 ?????/// Attribute/Operation/Type interfaces.
288 ?????SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
289 ?????????objectInterfaces;
290 ???};
在下面我們會看到,方言除了自己的接口,還支持操作/類型/屬性的外部模式接口,289行的objectInterfaces是存放這些操作接口的地方。這兩個容器用到這兩個定義:
30 ?using DialectInterfaceAllocatorFunction =
31 ?????std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
32 ?using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
顧名思義,這兩個std::function封裝的方法用于創建接口對象,在上面115與120行,它們被調用來創建具體的接口對象。Dialect的addInterface()將生成的接口對象保存在registeredInterfaces容器中(類型DenseMap<TypeID, std::unique_ptr<DialectInterface>>)。而120行處,創建的接口對象實際上保存在對應操作/類型/屬性抽象對象的容器中,下面會看到。