新智元報道
編輯:英智 好睏
【新智元導讀】SANA 1.5是一種高效可擴充套件的線性擴散Transformer,針對文字生成影象任務進行了三項創新:高效的模型增長策略、深度剪枝和推理時擴充套件策略。這些創新不僅大幅降低了訓練和推理成本,還在生成質量上達到了最先進的水平。
近年來,文字生成影象的技術不斷突破,但隨著模型規模的擴大,計算成本也隨之急劇上升。
為此,聯合MIT、清華、北大等機構的研究人員提出了一種高效可擴充套件的線性擴散Transformer——SANA,在大幅降低計算需求的情況下,還能保持有競爭力的效能。
SANA1.5在此基礎上,聚焦了兩個關鍵問題:
線性擴散Transformer的可擴充套件性如何?
在擴充套件大規模線性DiT時,怎樣降低訓練成本?
論文連結:https://arxiv.org/pdf/2501.18427
SANA 1.5:高效模型擴充套件三大創新
SANA 1.5在SANA 1.0(已被ICLR 2025接收)的基礎上,有三項關鍵創新。
首先,研究者提出了一種高效的模型增長策略,使得SANA可以從1.6B(20層)擴充套件到4.8B(60層)引數,同時顯著減少計算資源消耗,並結合了一種節省記憶體的8位最佳化器。
與傳統的從頭開始訓練大模型不同,透過有策略地初始化額外模組,可以讓大模型保留小模型的先驗知識。與從頭訓練相比,這種方法能減少60%的訓練時間。
其二,引入了模型深度剪枝技術,實現了高效的模型壓縮。透過識別並保留關鍵的塊,實現高效的模型壓縮,然後透過微調快速恢復模型質量,實現靈活的模型配置。
其三,研究者提出了一種推理期間擴充套件策略,引入了重複取樣策略,使得SANA在推理時透過計算而非引數擴充套件,使小模型也能達到大模型的生成質量。
透過生成多個樣本,並利用基於視覺語言模型(VLM)的選擇機制,將GenEval分數從0.72提升至0.80。
與從頭開始訓練大模型不同,研究者首先將一個包含N個Transformer層的基礎模型擴充套件到N+M層(在實驗中,N=20,M=40),同時保留其學到的知識。
在推理階段,採用兩種互補的方法,實現高效部署:
模型深度剪枝機制:識別並保留關鍵的Transformer塊,從而在小的微調成本下,實現靈活的模型配置。
推理時擴充套件策略:藉助重複取樣和VLM引導選擇,在計算資源和模型容量之間權衡。
同時,記憶體高效CAME-8bit最佳化器讓單個消費級GPU上微調十億級別的模型成為可能。
下圖展示了這些元件如何在不同的計算資源預算下協同工作,實現高效擴充套件。
模型增長
研究者提出一種高效的模型增長策略,目的是對預訓練的DiT模型進行擴充套件,把它從層增加到+層,同時保留模型已經學到的知識。
研究過程中,探索了三種初始化策略,最終選定部分保留初始化方法。這是因為該方法既簡單又穩定。
在這個策略裡,預訓練的N層繼續發揮特徵提取的作用,而新增加的M層一開始是隨機初始化,從恆等對映起步,慢慢學習最佳化特徵表示。
實驗結果顯示,與迴圈擴充套件和塊擴充套件策略相比,這種部分保留初始化方法在訓練時的動態表現最為穩定。
模型剪枝
本文提出了一種模型深度剪枝方法,能高效地將大模型壓縮成各種較小的配置,同時保持模型質量。
受Minitron啟發,透過輸入輸出相似性模式分析塊的重要性:
這裡的表示第i個transformer的第t個token。
模型的頭部和尾部塊的重要性較高,而中間層的輸入和輸出特徵相似性較高,表明這些層主要用於逐步最佳化生成的結果。根據排序後的塊重要性,對transformer塊進行剪枝。
剪枝會逐步削弱高頻細節,因為,在剪枝後進一步微調模型,以彌補資訊損失。
使用與大模型相同的訓練損失來監督剪枝後的模型。剪枝模型的適配過程非常簡單,僅需100步微調,剪枝後的1.6B引數模型就能達到與完整的4.8B引數模型相近的質量,並且優於SANA 1.0的1.6B模型。
推理時擴充套件
SANA 1.5經過充分訓練,在高效擴充套件的基礎上,生成能力有了顯著提升。受LLM推理時擴充套件的啟發,研究者也想透過這種方式,讓SANA 1.5表現得更好。
對SANA和很多擴散模型來說,增加去噪步數是一種常見的推理時擴充套件方法。但實際上,這個方法不太理想。一方面,新增的去噪步驟沒辦法修正之前出現的錯誤;另一方面,生成質量很快就會達到瓶頸。
相較而言,增加取樣次數是更有潛力的方向。
研究者用視覺語言模型(VLM)來判斷生成影象和文字提示是否匹配。他們以NVILA-2B為基礎模型,專門製作了一個資料集對其進行微調。
微調後的VLM能自動比較並評價生成的影象,經過多輪篩選,選出排名top-N的候選影象。這不僅確保了評選結果的可靠性,還能有效過濾與文字提示不匹配的影象。
模型增長、模型深度剪枝和推理擴充套件,構成了一個高效的模型擴充套件框架。三種方法協同配合,證明了精心設計的最佳化策略,遠比單純增加引數更有效。
模型增長策略探索了更大的最佳化空間,挖掘出更優質的特徵表示。
模型深度剪枝精準識別並保留了關鍵特徵,從而實現高效部署。
推理時間擴充套件表明,當模型容量有限時,藉助額外的推理時間和計算資源,能讓模型達到與大模型相似甚至更好的效果。
為了實現大模型的高效訓練與微調,研究者對CAME進行擴充套件,引入按塊8位量化,從而實現CAME-8bit最佳化器。
CAME-8bit相比AdamW-32bit減少了約8倍的記憶體使用,同時保持訓練的穩定性。
該最佳化器不僅在預訓練階段效果顯著,在單GPU微調場景中更是意義非凡。用RTX 4090這樣的消費級GPU,就能輕鬆微調SANA 4.8B。
研究揭示了高效擴充套件不僅僅依賴於增加模型容量。透過充分利用小模型的知識,並設計模型的增長-剪枝,更高的生成質量並不一定需要更大的模型。
SANA 1.5 評估結果
實驗表明,SANA 1.5的訓練收斂速度比傳統方法(擴大規模並從頭開始訓練)快2.5倍。
訓練擴充套件策略將GenEval分數從0.66提升至0.72,並透過推理擴充套件將其進一步提高至0.80,在GenEval基準測試中達到了最先進的效能。
模型增長
將SANA-4.8B與當前最先進的文字生成影象方法進行了比較,結果如表所示。
從SANA-1.6B到4.8B的擴充套件帶來了顯著的改進:GenEval得分提升0.06(從0.66增加到0.72),FID降低0.34(從5.76降至5.42),DPG得分提升0.2(從84.8增加到85.0)。
和當前最先進的方法相比,SANA-4.8B模型的引數數量少很多,卻能達到和大模型一樣甚至更好的效果。
SANA-4.8B的GenEval得分為0.72,接近Playground v3的0.76。
在執行速度上,SANA-4.8B的延遲比FLUX-dev(23.0秒)低5.5倍;吞吐量為0.26樣本/秒,是FLUX-dev(0.04樣本/秒)的6.5倍,這使得SANA-4.8B在實際應用中更具優勢。
模型剪枝
為了和SANA 1.0(1.6B)公平比較,此次訓練的SANA 1.5(4.8B)模型,沒有用高質量資料做監督微調。
所有結果都是針對512×512尺寸的影象評估得出的。經過修剪和微調的模型,僅用較低的計算成本,得分就達到了0.672,超過了從頭訓練模型的0.664。
推理時擴充套件
將推理擴充套件應用於SANA 1.5(4.8B)模型,並在GenEval基準上與其他大型影象生成模型進行了比較。
透過從2048張生成的影象中選擇樣本,經過推理擴充套件的模型在整體準確率上比單張影象生成提高了8%,在「顏色」「位置」和「歸屬」子任務上提升明顯。
不僅如此,藉助推理時擴充套件,SANA 1.5(4.8B)模型的整體準確率比Playground v3 (24B)高4%。
結果表明,即使模型容量有限,提高推理效率,也能提升模型生成影象的質量和準確性。
SANA:超高效文生圖
在這裡介紹一下之前的SANA工作。
SANA是一個超高效的文字生成影象框架,能生成高達4096×4096解析度的影象,不僅畫質清晰,還能讓影象和輸入文字精準匹配,而且生成速度超快,在膝上型電腦的GPU上就能執行。
SANA為何如此強大?這得益於它的創新設計:
深度壓縮自動編碼器:傳統自動編碼器壓縮影象的能力有限,一般只能壓縮8倍。而SANA的自動編碼器能達到32倍壓縮,大大減少了潛在tokens數量,計算效率也就更高了。
線性DiT:SANA用線性注意力替換了DiT中的標準注意力。在處理高解析度影象時,速度更快,還不會降低影象質量。
僅解碼文字編碼器:SANA不用T5做文字編碼器了,而是採用現代化的小型僅解碼大模型。同時,透過上下文學習,設計出更貼合實際需求的指令,讓生成的影象和輸入文字對應得更好。
高效訓練與取樣:SANA提出了Flow-DPM-Solver方法,減少了取樣步驟。再配合高效的字幕標註與選取,讓模型更快收斂。
經過這些最佳化,SANA-0.6B表現十分出色。
它生成影象的質量和像Flux-12B這樣的現代大型擴散模型差不多,但模型體積縮小了20倍,資料處理能力卻提升了100倍以上。
SANA-0.6B執行要求不高,在只有16GB視訊記憶體的筆記本GPU上就能執行,生成一張1024×1024解析度的影象,用時不到1秒。
這意味著,創作者們用普通的膝上型電腦,就能輕鬆製作高質量影象,大大降低了內容創作的成本。
研究者提出新的深度壓縮自動編碼器,將壓縮比例提升到32倍,和壓縮比例為8倍的自動編碼器相比,F32自動編碼器生成的潛在tokens減少了16倍。
這一改進對於高效訓練和超高解析度影象生成,至關重要。
研究者提出一種全新的線性DiT,用線性注意力替代傳統的二次複雜度注意力,將計算複雜度從原本的O(N²) 降低至O(N)。另一方面,在MLP層引入3×3深度可分卷積,增強潛在tokens的區域性資訊。
在生成效果上,線性注意力與傳統注意力相當,在生成4K影象時,推理延遲降低了1.7倍。Mix-FFN結構讓模型無需位置編碼,也能生成高質量影象,這讓它成為首個無需位置嵌入的DiT變體。
在文字編碼器的選擇上,研究者選用了僅解碼的小型大語言模型Gemma,以此提升對提示詞的理解與推理能力。相較於CLIP和T5,Gemma在文字理解和指令執行方面表現更為出色。
為充分發揮Gemma的優勢,研究者最佳化訓練穩定性,設計複雜人類指令,藉助Gemma的上下文學習能力,進一步提高了影象與文字的匹配質量。
研究者提出一種自動標註與訓練策略,藉助多個視覺語言模型(VLM)生成多樣化的重新描述文字。然後,運用基於CLIPScore的策略,篩選出CLIPScore較高的描述,以此增強模型的收斂性和對齊效果。
在推理環節,相較於Flow-Euler-Solver,Flow-DPM-Solver將推理步驟從28-50步縮減至14-20步,不僅提升了速度,生成效果也更為出色。
參考資料:
https://huggingface.co/papers/2501.18427
https://x.com/xieenze_jr/status/1885510823767875799
https://nvlabs.github.io/SANA/