Skip to content

Commit ad8d3b7

Browse files
authored
Merge pull request #3 from Qamil-Mirza/release/0.1.0
Merge Release V0.1.0 To Main
2 parents 64cb15c + d7b2d94 commit ad8d3b7

File tree

6 files changed

+750270
-36
lines changed

6 files changed

+750270
-36
lines changed

app.py

Lines changed: 135 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,162 @@
11
import os
22
import dash
3-
from dash import dcc, html
3+
from dash import dcc, html, callback_context
44
import plotly.express as px
5-
from dash.dependencies import Input, Output
5+
from dash.dependencies import Input, Output, State
66
import pandas as pd
7+
import math
78

89
# App settings
9-
CSV_FILE = '1M_points.csv'
10+
CSV_FILE = 'main.csv'
1011
PLOT_DIMS = 2
12+
IMAGES_PER_PAGE = 12
13+
INITIAL_PAGE_NUM = 1
14+
15+
# Initial setup
16+
page_num = INITIAL_PAGE_NUM
1117

12-
# Read in the data
1318
df = pd.read_csv(f"./embeddings_data/{CSV_FILE}")
19+
PLOT_TITLE = f"Image Embeddings Visualization: {df.shape[0]} Images"
1420

15-
# Check if the num rows > 500k, assert to many rows for 3d
1621
if df.shape[0] > 500000:
1722
assert PLOT_DIMS == 2, "Too many rows for 3D plot. Set PLOT_DIMS to 2."
1823

19-
# Create a Plotly 3D scatter plot with color coding by class label
2024
if PLOT_DIMS == 2:
21-
fig = px.scatter(df, x='x', y='y', color='label', hover_data=['image_path'], opacity=0.5, render_mode='webgl')
25+
fig = px.scatter(df, x='x', y='y', color='label', hover_data=['image_path'], opacity=0.75, symbol='label', render_mode='webgl')
2226
elif PLOT_DIMS == 3:
2327
fig = px.scatter_3d(df, x='x', y='y', z='z', color='label', hover_data=['image_path'])
2428
else:
2529
raise ValueError("Invalid number of dimensions. Choose 2 or 3.")
2630

27-
# Set up Dash app
28-
app = dash.Dash(__name__)
31+
fig.update_layout(title_text=PLOT_TITLE, title_x=0.5, clickmode='event+select')
2932

33+
# Set up Dash app layout
34+
app = dash.Dash(__name__)
35+
server = app.server
3036
app.layout = html.Div([
31-
html.H1(f"Image Embeddings Visualization w/ {df.shape[0]} Images"),
32-
html.P("Click on a point in the scatter plot to display the image and its class label. Click on the legend to toggle classes."),
33-
html.Div(className='container', children=[
34-
dcc.Graph(id='scatter-plot', figure=fig),
35-
html.Div(id='image-display', children=[
36-
html.Img(id='selected-image', src=''),
37-
html.P(id='selected-label')
38-
])
39-
])
37+
html.H1("SUPA Embeddings Visualizer"),
38+
html.P("Click on a point to see the image and class label or use the lasso/box-select tool to select multiple points."),
39+
html.P("Click on the legend to toggle classes on/off. Hold down shift while clicking on points to cherry pick multiple points"),
40+
dcc.Graph(id='scatter-plot', figure=fig),
41+
html.H2("Selected Points"),
42+
html.Div(id='select-data'),
43+
html.Div([
44+
html.Button('<', id='decrement-button', n_clicks=0),
45+
html.P(id='page-num-display', children=f'{page_num}'),
46+
html.Button('>', id='increment-button', n_clicks=0),
47+
], id='pagination'),
48+
html.Div(id='hidden-page-num', style={'display': 'none'}, children=f'{page_num}')
4049
])
4150

4251
@app.callback(
43-
[Output('selected-image', 'src'),
44-
Output('selected-label', 'children')],
45-
Input('scatter-plot', 'clickData')
52+
Output('select-data', 'children'),
53+
[Input('scatter-plot', 'clickData'),
54+
Input('scatter-plot', 'selectedData'),
55+
Input('hidden-page-num', 'children')]
56+
)
57+
def display_images(clickData, selectedData, page_num):
58+
items = []
59+
page_num = int(page_num)
60+
61+
start_index = (page_num - 1) * IMAGES_PER_PAGE
62+
end_index = start_index + IMAGES_PER_PAGE
63+
64+
# Handle clickData for a single point
65+
if clickData:
66+
items = []
67+
image_url = clickData['points'][0]['customdata'][0]
68+
image_path = image_url.replace('./assets/', '')
69+
label = os.path.basename(os.path.dirname(image_url))
70+
items.append(
71+
html.Div([
72+
html.Img(src=app.get_asset_url(image_path), style={'height': '150px', 'margin': '5px'}),
73+
html.P(f"Class: {label}", style={'text-align': 'center'})
74+
], style={'display': 'inline-block', 'margin': '10px'})
75+
)
76+
77+
# Handle selectedData for multiple points
78+
if selectedData:
79+
items = []
80+
for point in selectedData['points'][start_index:end_index]:
81+
image_url = point['customdata'][0]
82+
image_path = image_url.replace('./assets/', '')
83+
label = os.path.basename(os.path.dirname(image_url))
84+
items.append(
85+
html.Div([
86+
html.Img(src=app.get_asset_url(image_path), style={'height': '150px', 'margin': '5px'}),
87+
html.P(f"Class: {label}", style={'text-align': 'center'})
88+
], style={'display': 'inline-block', 'margin': '10px'})
89+
)
90+
91+
# Handle the case when no points are selected or clicked
92+
if not items:
93+
items = [html.Div([
94+
html.Img(src='https://placedog.net/640/224?random'),
95+
html.P('No Points Selected. A wild Doge appears!')
96+
])]
97+
98+
return items
99+
100+
# Pagination
101+
@app.callback(
102+
[Output('page-num-display', 'children'),
103+
Output('hidden-page-num', 'children')],
104+
[Input('increment-button', 'n_clicks'),
105+
Input('decrement-button', 'n_clicks'),
106+
Input('scatter-plot', 'relayoutData')],
107+
[State('hidden-page-num', 'children'),
108+
State('scatter-plot', 'selectedData')]
46109
)
47-
def display_image_and_label(clickData):
48-
if clickData is None:
49-
return 'https://placedog.net/640/224?random', 'Wild Doge appears!'
50-
# Get the index of the clicked point
51-
image_url = clickData['points'][0]['customdata'][0]
52-
# Get the corresponding image path and label
53-
image_path = image_url.replace('./assets/', '')
54-
label = os.path.basename(os.path.dirname(image_url))
55-
return app.get_asset_url(image_path), f"Class: {label}"
110+
def update_page_num(increment_clicks, decrement_clicks, relayoutData, page_num, selectedData):
111+
page_num = int(page_num)
112+
ctx = callback_context
113+
114+
if not ctx.triggered:
115+
return f'{page_num}', f'{page_num}'
116+
else:
117+
button_id = ctx.triggered[0]['prop_id'].split('.')[0]
118+
119+
# Check if double-click occurred (reset page number to 1)
120+
if relayoutData and 'xaxis.range' in relayoutData and 'yaxis.range' in relayoutData:
121+
if 'autosize' in relayoutData:
122+
page_num = INITIAL_PAGE_NUM
123+
else:
124+
page_num = int(page_num)
125+
126+
# Calculate the total number of pages
127+
total_items = len(selectedData['points']) if selectedData else 0
128+
total_pages = math.ceil(total_items / IMAGES_PER_PAGE)
129+
130+
if button_id == 'increment-button' and increment_clicks and page_num < total_pages:
131+
page_num += 1
132+
elif button_id == 'decrement-button' and decrement_clicks and page_num > 1:
133+
page_num -= 1
134+
else:
135+
page_num = 1
136+
137+
return f'{page_num}', f'{page_num}'
138+
139+
# Button Disable
140+
@app.callback(
141+
[Output('decrement-button', 'disabled'),
142+
Output('increment-button', 'disabled')],
143+
[Input('hidden-page-num', 'children'),
144+
Input('scatter-plot', 'selectedData')],
145+
[State('scatter-plot', 'selectedData')]
146+
)
147+
def update_button_disabled(page_num, selectedData, stateSelectedData):
148+
selectedData = selectedData or stateSelectedData
149+
total_items = len(selectedData['points']) if selectedData else 0
150+
total_pages = math.ceil(total_items / IMAGES_PER_PAGE)
151+
page_num = int(page_num)
152+
153+
if total_items <= 12:
154+
return True, True
155+
156+
decrement_disabled = page_num <= 1
157+
increment_disabled = page_num >= total_pages
158+
159+
return decrement_disabled, increment_disabled
56160

57161
if __name__ == '__main__':
58-
app.run_server(debug=True)
162+
app.run_server(debug=True)

assets/header.css

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
body {
22
font-family: sans-serif;
33
height: 100vh;
4+
text-align: center;
45
}
56

6-
h1, p {
7+
h1, h4, p {
78
color: #333;
89
font-family: sans-serif;
910
}
@@ -13,15 +14,17 @@ h1, p {
1314
flex-direction: row;
1415
justify-content: center;
1516
width: 100%;
17+
margin-top: 5rem;
1618
}
1719

1820
#scatter-plot {
1921
display: flex;
2022
flex-direction: column;
2123
justify-content: center;
2224
align-items: center;
23-
width: 80%;
24-
height: 80vh;
25+
width: 100%;
26+
height: 100vh;
27+
margin-top: 2rem;
2528
}
2629

2730
#image-display {
@@ -39,4 +42,44 @@ h1, p {
3942
#selected-label {
4043
margin-top: 10px;
4144
font-size: 20px;
42-
}
45+
}
46+
47+
#pagination {
48+
display: flex;
49+
flex-direction: row;
50+
justify-content: center;
51+
align-items: center;
52+
}
53+
54+
button {
55+
padding: 10px 20px;
56+
background-color: #ffffff;
57+
color: hwb(0 20% 80%);
58+
border: 2px solid #4d4d4d;
59+
border-radius: 5px;
60+
cursor: pointer;
61+
font-size: 16px;
62+
}
63+
64+
button:disabled {
65+
background-color: #f2f2f2;
66+
color: #999;
67+
cursor: not-allowed;
68+
border: 2px solid #999;
69+
}
70+
71+
button:not(:disabled):hover {
72+
border: 2px solid #7B68EE;
73+
}
74+
75+
#decrement-button {
76+
margin-right: 1rem;
77+
}
78+
79+
#increment-button {
80+
margin-left: 1rem;
81+
}
82+
83+
#page-num-display {
84+
font-size: large;
85+
}

0 commit comments

Comments
 (0)