🌈?個人主頁:十二月的貓-CSDN博客
🔥?系列專欄:?🏀大模型實戰訓練營_十二月的貓的博客-CSDN博客💪🏻?十二月的寒冬阻擋不了春天的腳步,十二點的黑夜遮蔽不住黎明的曙光
目錄
1. 前言
2. LayoutPrompt介紹
3. LayoutPrompt · 布局序列化模塊
3.1 固定化Prompt
3.2?父類序列化模塊
3.3?子類序列化模塊
4. 總結
1. 前言
? ? ? ? 貓貓不知道大家有沒有思考過圖片布局生成模型,這算是生成模型的一個非常小的子任務了。前面帶大家學習過生成模型,包括GAN、Diffusion等。這些都算是生成模型的研究子領域,利用這些子領域的知識,我們可以來研究具體的任務,例如生成圖片、按照語言提示生成圖片、按照布局提示生成圖片等。
????????同樣生成布局也是生成模型中的一個具體任務,其實思路也是非常簡單。生成圖片這個任務中更具體的任務是生成海報、生成照片、生成動漫圖片等。因為直接生成一個海報難度太大,我們就先去生成布局,然后在具體布局的約束下去具體生成完整圖片。可以簡單給大家看一下布局是什么:
? ? ? ? 貓貓研究這個呢,主要還是因為創新實訓和軟件創新大賽兩個需要。同時由于貓貓也研究過一點生成模型,因此這個學期就研究一下這個領域啦~~如果大家也對這個領域感興趣,可以關注我們團隊的專欄(布局生成模型是我們海報生成系統中的一個子模塊,希望更多貓友參與到我們的開發當中哦):大模型實戰訓練營_十二月的貓的博客-CSDN博客
2. LayoutPrompt介紹
總的來說,為了完成LayoutPrompt模型,我們需要完成的子模塊有:
- 數據預處理模塊:該模塊主要就是利用基礎數據預處理方法對數據集中的所有數據樣本進行預處理。
- 動態樣本選擇模塊:從訓練集中檢索最相關的樣本,然后作為最直接的上下文(約束信息)送給大語言模型。
- 布局序列化模塊:用于將上面所選的樣本布局轉化為序列表述(因為大語言模型對序列化輸入有更好的效果)。序列化數據就是類似 自然語言、代碼等。
- 大語言模型模塊:將序列化處理后的所有樣本一起送給大語言模型,讓大語言模型參考的情況下給出自己的答案。
- 大語言模型解析模塊:用于將大語言模型給出的布局結果解析為標準化的輸出。
- 布局排序模塊(布局評價模塊):評價大語言模型生成的布局的質量分數,并做一個排序。?
具體的模型圖如下:
模型運行具體流程如下:
- 用戶輸入前導信息,例如畫布大小,任務類型等,數據預處理模塊預處理數據庫(用戶根據自己的任務可以選擇數據庫,如海報布局數據庫、手機UI設計數據庫)中的所有數據。
- 動態樣本選擇模塊得到處理后的數據庫數據,然后根據用戶輸入的前導信息選擇合適的example樣本。
- 將樣本+前導信息+測試樣本送給大語言模型。
- 由大語言模型生成最終的layout,送給Rank模塊。
- Rank模塊排序后分數最高的就是最終輸出。
3. LayoutPrompt · 布局序列化模塊
布局序列化模塊:本質就是序列化+Prompt。兩者核心都是固定。
序列化:將輸入輸出按照固定序列格式調整。
? ? ? ? 例如輸出固定如下:
- "標題 0 0.1 0.2 0.3 0.4 | 正文 1 0.5 0.6 0.7 0.8"
# < html > # < body > # < div # style = "..." > 標題_0 < / div > # < div # style = "..." > 正文_1 < / div > # < / body > # < / htmlPrompt:根據不同任務,輸入需要不同Prompt,同時序列化格式。
? ? ? ? 例如輸入固定如下:
![]()
一句話來說,布局序列化模塊準備的就是模型中的這一部分:
- 前導部分(PREAMBLE):固定Prompt,用戶輸入畫布大小高度等。
- 輸入限制(INPUT CONSTRAINT) :會有兩種表示形式1.如上圖的seq形式;2.html形式。
- 輸出布局(OUTPUT LAYOUT):得到布局的坐標data后(由其他模塊負責),序列化輸出上圖結果。
從代碼角度來說分為三個部分:
- 固定化Prompt。
- 父類序列化模塊(固定輸出序列結構,輸入序列留接口給子類實現)
- 子類序列化模塊(一個子類對應一個具體的任務,不同任務輸入結構不一樣)
從任務角度存在以下七種:
- 元素類型任務(限制layout中的元素)
- 元素類型,元素尺寸限制任務
- 元素類型,元素之間位置關系
- 元素補全(根據部分已知布局元素,生成完整布局結構)
- 元素尺寸修正(給出類型和尺寸,模型自己修正尺寸)
- 防止遮擋的布局生成
- 文本描述下的布局生成
3.1 固定化Prompt
PREAMBLE = ("Please generate a layout based on the given information. ""You need to ensure that the generated layout looks realistic, with elements well aligned and avoiding unnecessary overlap.\n""Task Description: {}\n""Layout Domain: {} layout\n""Canvas Size: canvas width is {}px, canvas height is {}px"
)
# html的頭部
HTML_PREFIX = """<html>
<body>
<div class="canvas" style="left: 0px; top: 0px; width: {}px; height: {}px"></div>
"""
# html的結尾
HTML_SUFFIX = """</body>
</html>"""# html的body
HTML_TEMPLATE = """<div class="{}" style="left: {}px; top: {}px; width: {}px; height: {}px"></div>
"""HTML_TEMPLATE_WITH_INDEX = """<div class="{}" style="index: {}; left: {}px; top: {}px; width: {}px; height: {}px"></div>
"""
- 本部分主要是為了固定輸入輸出的一些序列化格式,同時固定前導部分。
- 用戶在這里通過前端輸入自己想要的畫布大小高度等信息。
3.2?父類序列化模塊
# 序列化的父類,后面有具體任務不同的序列化
# 所謂序列化本質就是固定輸入的prompt形式,同時固定輸出的一個形式。
class Serializer:def __init__(self,input_format: str,output_format: str,index2label: dict,canvas_width: int,canvas_height: int,add_index_token: bool = True,add_sep_token: bool = True,sep_token: str = "|",add_unk_token: bool = False,unk_token: str = "<unk>",):self.input_format = input_formatself.output_format = output_formatself.index2label = index2labelself.canvas_width = canvas_widthself.canvas_height = canvas_heightself.add_index_token = add_index_tokenself.add_sep_token = add_sep_tokenself.sep_token = sep_tokenself.add_unk_token = add_unk_tokenself.unk_token = unk_tokendef build_input(self, data):if self.input_format == "seq":return self._build_seq_input(data)elif self.input_format == "html":return self._build_html_input(data)else:raise ValueError(f"Unsupported input format: {self.input_format}")# check value is not nulldef _build_seq_input(self, data):raise NotImplementedErrordef _build_html_input(self, data):raise NotImplementedErrordef build_output(self, data, label_key="labels", bbox_key="discrete_gold_bboxes"):if self.output_format == "seq":return self._build_seq_output(data, label_key, bbox_key)elif self.output_format == "html":return self._build_html_output(data, label_key, bbox_key)# # 輸入數據結構示例# data = {# "labels": [0, 1], # 標簽索引# "discrete_gold_bboxes": [ # 坐標列表(假設已離散化)# [0.1, 0.2, 0.3, 0.4],# [0.5, 0.6, 0.7, 0.8]# ]# }# "標題 0 0.1 0.2 0.3 0.4 | 正文 1 0.5 0.6 0.7 0.8"def _build_seq_output(self, data, label_key, bbox_key):# 在字典中存儲的標簽信息,和邊框信息(不是存儲具體字,而是存儲信息)labels = data[label_key]bboxes = data[bbox_key]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]bbox = bboxes[idx].tolist()tokens.append(label)if self.add_index_token:tokens.append(str(idx))tokens.extend(map(str, bbox)) # extend一次性添加很多值,append一次添加一個值。map(function,list):把function作用在list上。str():將其他類型轉為str類型。if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token) # 添加隔離符號return " ".join(tokens)# # 輸出結構# < html ># < body ># < div# style = "..." > 標題_0 < / div ># < div# style = "..." > 正文_1 < / div ># < / body ># < / html >def _build_html_output(self, data, label_key, bbox_key):labels = data[label_key]bboxes = data[bbox_key]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)] # 使用 HTML_PREFIX 作為 HTML 頁面的開頭,并將畫布寬度和高度傳遞進去。_TEMPLATE = HTML_TEMPLATE_WITH_INDEX if self.add_index_token else HTML_TEMPLATE # 根據 add_index_token 決定使用哪種模板:帶索引的模板或普通模板。for idx in range(len(labels)):label = self.index2label[int(labels[idx])]bbox = bboxes[idx].tolist()element = [label]if self.add_index_token:element.append(str(idx))element.extend(map(str, bbox))htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)return "".join(htmls)
- 主要定義兩種輸出方式的序列化。第一種是以seq的形式輸出;第二種是以html的形式輸出。
- 輸入形式的序列化(prompt+序列化)僅僅定義了模板交給子類(具體任務)來實現。
- expand和append都是列表后追加,但expand一次追加很多個元素,append一次加一個。
- map(function,list):將function作用在所有的list元素上。
- 作用:1.?得到其他模塊給的Layout坐標data后,將其轉化為html或seq的固定序列化格式輸出。2.作為父類將輸入序列化交給子類實現。
3.3?子類序列化模塊
????????前面也說了,子類序列化模塊需要根據不同的任務要求給出不同的prompt以及序列化輸入。因此有多少個任務就會有多少個子類序列化模塊。在這里,貓貓僅僅展示兩個模塊的代碼,完整的代碼等專欄更新結束后,會同步放在Gitee以及CSDN賬號下。
任務一:限制元素類型的布局生成
class GenTypeSerializer(Serializer):task_type = "generation conditioned on given element types"constraint_type = ["Element Type Constraint: "]HTML_TEMPLATE_WITHOUT_ANK = '<div class="{}"></div>\n'HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEX = '<div class="{}" style="index: {}"></div>\n'def _build_seq_input(self, data):labels = data["labels"]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]tokens.append(label)if self.add_index_token:tokens.append(str(idx))if self.add_unk_token:tokens += [self.unk_token] * 4if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token)return " ".join(tokens)def _build_html_input(self, data):labels = data["labels"]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)]if self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATE_WITH_INDEXelif self.add_index_token and not self.add_unk_token:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEXelif not self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATEelse:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANKfor idx in range(len(labels)):label = self.index2label[int(labels[idx])]element = [label]if self.add_index_token:element.append(str(idx))if self.add_unk_token:element += [self.unk_token] * 4htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)return "".join(htmls)def build_input(self, data):return self.constraint_type[0] + super().build_input(data)
任務二:限制元素類型,以及元素關系的布局生成
class GenRelationSerializer(Serializer):task_type = ("generation conditioned on given element relationships\n""'A left B' means that the center coordinate of A is to the left of the center coordinate of B. ""'A right B' means that the center coordinate of A is to the right of the center coordinate of B. ""'A top B' means that the center coordinate of A is above the center coordinate of B. ""'A bottom B' means that the center coordinate of A is below the center coordinate of B. ""'A center B' means that the center coordinate of A and the center coordinate of B are very close. ""'A smaller B' means that the area of A is smaller than the ares of B. ""'A larger B' means that the area of A is larger than the ares of B. ""'A equal B' means that the area of A and the ares of B are very close. ""Here, center coordinate = (left + width / 2, top + height / 2), ""area = width * height")constraint_type = ["Element Type Constraint: ", "Element Relationship Constraint: "]HTML_TEMPLATE_WITHOUT_ANK = '<div class="{}"></div>\n'HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEX = '<div class="{}" style="index: {}"></div>\n'def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.index2type = RelationTypes.index2type()def _build_seq_input(self, data):labels = data["labels"]relations = data["relations"]tokens = []for idx in range(len(labels)):label = self.index2label[int(labels[idx])]tokens.append(label)if self.add_index_token:tokens.append(str(idx))if self.add_unk_token:tokens += [self.unk_token] * 4if self.add_sep_token and idx < len(labels) - 1:tokens.append(self.sep_token)type_cons = " ".join(tokens)if len(relations) == 0:return self.constraint_type[0] + type_constokens = []for idx in range(len(relations)):label_i = relations[idx][2]index_i = relations[idx][3]if label_i != 0:tokens.append("{} {}".format(self.index2label[int(label_i)], index_i))else:tokens.append("canvas")tokens.append(self.index2type[int(relations[idx][4])])label_j = relations[idx][0]index_j = relations[idx][1]if label_j != 0:tokens.append("{} {}".format(self.index2label[int(label_j)], index_j))else:tokens.append("canvas")if self.add_sep_token and idx < len(relations) - 1:tokens.append(self.sep_token)relation_cons = " ".join(tokens)return (self.constraint_type[0]+ type_cons+ "\n"+ self.constraint_type[1]+ relation_cons)def _build_html_input(self, data):labels = data["labels"]relations = data["relations"]htmls = [HTML_PREFIX.format(self.canvas_width, self.canvas_height)]if self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATE_WITH_INDEXelif self.add_index_token and not self.add_unk_token:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANK_WITH_INDEXelif not self.add_index_token and self.add_unk_token:_TEMPLATE = HTML_TEMPLATEelse:_TEMPLATE = self.HTML_TEMPLATE_WITHOUT_ANKfor idx in range(len(labels)):label = self.index2label[int(labels[idx])]element = [label]if self.add_index_token:element.append(str(idx))if self.add_unk_token:element += [self.unk_token] * 4htmls.append(_TEMPLATE.format(*element))htmls.append(HTML_SUFFIX)type_cons = "".join(htmls)if len(relations) == 0:return self.constraint_type[0] + type_constokens = []for idx in range(len(relations)):label_i = relations[idx][2]index_i = relations[idx][3]if label_i != 0:tokens.append("{} {}".format(self.index2label[int(label_i)], index_i))else:tokens.append("canvas")tokens.append(self.index2type[int(relations[idx][4])])label_j = relations[idx][0]index_j = relations[idx][1]if label_j != 0:tokens.append("{} {}".format(self.index2label[int(label_j)], index_j))else:tokens.append("canvas")if self.add_sep_token and idx < len(relations) - 1:tokens.append(self.sep_token)relation_cons = " ".join(tokens)return (self.constraint_type[0]+ type_cons+ "\n"+ self.constraint_type[1]+ relation_cons)
- 針對此任務設計了具體的Prompt形式。
- 輸入數據序列化有兩個部分:1、seq序列化;2、html序列化
- seq序列化:用戶輸入限制后,需要結合Prompt序列化到INPUT_CONSTRAINT(
- (注意:這個seq并不是完整的,僅僅是將用戶的限制要求填入):
- html序列化:根據用戶輸入限制生成對應的html格式(注意:這個html并不是完整的,僅僅是將用戶的限制要求填入)
- 作用:1.將用戶輸入的限制轉化為某種序列化(seq序列化 或 html序列化)。2.結合固定Prompt包裝成完整的INPUT CONSTRAINT
4. 總結
本篇文章帶大家深入了解了PosterGenius項目的Layout生成部分的第一篇,后續將更新Layout系列的第二篇。歡迎大家繼續支持貓貓呀!!
?【如果想學習更多深度學習文章,可以訂閱一下熱門專欄】
- 《PyTorch科研加速指南:即插即用式模塊開發》_十二月的貓的博客-CSDN博客
- 《深度學習理論直覺三十講》_十二月的貓的博客-CSDN博客
- 《AI認知筑基三十講》_十二月的貓的博客-CSDN博客
如果想要學習更多pyTorch/python編程的知識,大家可以點個關注并訂閱,持續學習、天天進步你的點贊就是我更新的動力,如果覺得對你有幫助,辛苦友友點個贊,收個藏呀~~~
本文撰寫人:十二月的貓? ?十二月的貓-CSDN博客