-
Notifications
You must be signed in to change notification settings - Fork 43
Description
This is the relevant code I wrote myself, but I can't get the effect shown in figure 4 of the paper.
import numpy as np
import seaborn as sns
from cuml.manifold import TSNE
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_style("ticks")
##Learned representations from CoST
dataset_path = 'xxxx'
read_data = pd.read_csv(dataset_path+'representations.csv')
##Top 160 are trends representation
trends = np.array(read_data.iloc[:,1:161])
##Last 160 are seasons representation
seasons = np.array(read_data.iloc[:,161:])
##Perform T-SNE on the trends with a fixed season
trend_tsne = TSNE(n_components=2).fit_transform(trends)
seasonal_tsne = TSNE(n_components=2).fit_transform(seasons)
fig, axs = plt.subplots(2, 1, figsize=(8, 12))
##After fixing a certain 160-dimensional seasonal item, draw the two trend item cluster pictures
for i in range(2):
sns.scatterplot(x=trend_tsne[:, 0], y=trend_tsne[:, 1], hue=seasons[:, i+1], ax=axs[0], palette=['yellow', 'purple'])
axs[0].set_title('Fixed Seasonal Item {}'.format(i+1))
axs[0].set_xlabel('TSNE Dimension 1')
axs[0].set_ylabel('TSNE Dimension 2')
##After fixing a certain trend item in the first 160 dimensions, draw the clustering pictures of the three seasonal items
for i in range(3):
sns.scatterplot(x=seasonal_tsne[:, 0], y=seasonal_tsne[:, 1], hue=trends[:, i+1], ax=axs[1], palette=['yellow', 'blue', 'purple'])
axs[1].set_title('Fixed Trend Item {}'.format(i+1))
axs[1].set_xlabel('TSNE Dimension 1')
axs[1].set_ylabel('TSNE Dimension 2')
plt.tight_layout()
plt.show()
Any help will be appreciated.