Skip to content

Temporal Fusion Transformer model實作,目的為熟悉特徵工程、建模流程及預測結果視覺化,並深入研究模型架構與內部邏輯,強化對新模型的理解能力。當前仍在優化中

Notifications You must be signed in to change notification settings

Lethe4564518/TemporalFusionTransformer-model

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

12 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Temporal Fusion Transformer 時間序列預測模型實作

專案說明

此專案實作了 Temporal Fusion Transformer (TFT) 模型,用於台灣股市月報酬率的預測。 資料來源為 TEJ 中台灣市值前 96 的上市公司,並涵蓋了 10 年的各式月資料。

核心功能

  • 使用 Temporal Fusion Transformer 架構進行多公司的時間序列預測
  • 實現交叉驗證評估模型性能 (GroupKFold)
  • 具備數據前處理流程,包括特徵工程及缺失值處理
  • 針對股票資料進行技術指標計算
  • 盡可能地利用 GPU 加速
  • 模型評估及結果視覺化
  • 方便調整的超參數區塊

技術特點

  • 深度學習框架:PyTorch + PyTorch Lightning
  • 預測模型:Temporal Fusion Transformer (使用 pytorch-forecasting )
  • 關鍵技術
    • 注意力機制
    • GroupKFold 交叉驗證
    • 梯度累積
    • 早停機制

模型架構

TFT 模型結合了 LSTM 和 Transformer 的優點,特別適合 panel data 預測:

  • 變量選擇:識別重要特徵
  • 解釋性:注意力機制提供結果可解釋性
  • 多變量輸入:同時處理分類和連續特徵

環境要求

  • Python 3.8+
  • PyTorch 1.10+
  • pytorch-forecasting
  • PyTorch Lightning
  • pandas, numpy, matplotlib, scikit-learn
  • 與 GPU 相符之 CUDA 版本

實驗結果

模型在交叉驗證評估中取得了良好的表現:

  • 平均 RMSE: 約 0.0922
  • 預測結果散佈圖顯示全部 Fold 綜合來看模型能初步捕捉月報酬率趨勢
  • 極端值預測、特徵選擇及超參數設置上仍有許多改進空間

使用說明

  1. 安裝所需的套件
  2. 檢查資料集 (data.csv) 是否與執行檔 (Transformer.py) 在同一資料夾
  3. 運行 Transformer.py 開始訓練
  4. 查看 evaluation_results.png 了解模型表現

未來改進方向

  • 嘗試不同的損失函數組合
  • 增加更多金融領域的特徵工程
  • 深入調整模型內部架構以更好地捕捉報酬率的特性
  • 使用集成學習以提高預測穩定性
  • 利用分散式訓練提高訓練效率
  • 添加更多預測結果及模型表現可視化

About

Temporal Fusion Transformer model實作,目的為熟悉特徵工程、建模流程及預測結果視覺化,並深入研究模型架構與內部邏輯,強化對新模型的理解能力。當前仍在優化中

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages