11import os
22import dash
3- from dash import dcc , html
3+ from dash import dcc , html , callback_context
44import plotly .express as px
5- from dash .dependencies import Input , Output
5+ from dash .dependencies import Input , Output , State
66import pandas as pd
7+ import math
78
89# App settings
9- CSV_FILE = '1M_points .csv'
10+ CSV_FILE = 'main .csv'
1011PLOT_DIMS = 2
12+ IMAGES_PER_PAGE = 12
13+ INITIAL_PAGE_NUM = 1
14+
15+ # Initial setup
16+ page_num = INITIAL_PAGE_NUM
1117
12- # Read in the data
1318df = pd .read_csv (f"./embeddings_data/{ CSV_FILE } " )
19+ PLOT_TITLE = f"Image Embeddings Visualization: { df .shape [0 ]} Images"
1420
15- # Check if the num rows > 500k, assert to many rows for 3d
1621if df .shape [0 ] > 500000 :
1722 assert PLOT_DIMS == 2 , "Too many rows for 3D plot. Set PLOT_DIMS to 2."
1823
19- # Create a Plotly 3D scatter plot with color coding by class label
2024if PLOT_DIMS == 2 :
21- fig = px .scatter (df , x = 'x' , y = 'y' , color = 'label' , hover_data = ['image_path' ], opacity = 0.5 , render_mode = 'webgl' )
25+ fig = px .scatter (df , x = 'x' , y = 'y' , color = 'label' , hover_data = ['image_path' ], opacity = 0.75 , symbol = 'label' , render_mode = 'webgl' )
2226elif PLOT_DIMS == 3 :
2327 fig = px .scatter_3d (df , x = 'x' , y = 'y' , z = 'z' , color = 'label' , hover_data = ['image_path' ])
2428else :
2529 raise ValueError ("Invalid number of dimensions. Choose 2 or 3." )
2630
27- # Set up Dash app
28- app = dash .Dash (__name__ )
31+ fig .update_layout (title_text = PLOT_TITLE , title_x = 0.5 , clickmode = 'event+select' )
2932
33+ # Set up Dash app layout
34+ app = dash .Dash (__name__ )
35+ server = app .server
3036app .layout = html .Div ([
31- html .H1 (f"Image Embeddings Visualization w/ { df .shape [0 ]} Images" ),
32- html .P ("Click on a point in the scatter plot to display the image and its class label. Click on the legend to toggle classes." ),
33- html .Div (className = 'container' , children = [
34- dcc .Graph (id = 'scatter-plot' , figure = fig ),
35- html .Div (id = 'image-display' , children = [
36- html .Img (id = 'selected-image' , src = '' ),
37- html .P (id = 'selected-label' )
38- ])
39- ])
37+ html .H1 ("SUPA Embeddings Visualizer" ),
38+ html .P ("Click on a point to see the image and class label or use the lasso/box-select tool to select multiple points." ),
39+ html .P ("Click on the legend to toggle classes on/off. Hold down shift while clicking on points to cherry pick multiple points" ),
40+ dcc .Graph (id = 'scatter-plot' , figure = fig ),
41+ html .H2 ("Selected Points" ),
42+ html .Div (id = 'select-data' ),
43+ html .Div ([
44+ html .Button ('<' , id = 'decrement-button' , n_clicks = 0 ),
45+ html .P (id = 'page-num-display' , children = f'{ page_num } ' ),
46+ html .Button ('>' , id = 'increment-button' , n_clicks = 0 ),
47+ ], id = 'pagination' ),
48+ html .Div (id = 'hidden-page-num' , style = {'display' : 'none' }, children = f'{ page_num } ' )
4049])
4150
4251@app .callback (
43- [Output ('selected-image' , 'src' ),
44- Output ('selected-label' , 'children' )],
45- Input ('scatter-plot' , 'clickData' )
52+ Output ('select-data' , 'children' ),
53+ [Input ('scatter-plot' , 'clickData' ),
54+ Input ('scatter-plot' , 'selectedData' ),
55+ Input ('hidden-page-num' , 'children' )]
56+ )
57+ def display_images (clickData , selectedData , page_num ):
58+ items = []
59+ page_num = int (page_num )
60+
61+ start_index = (page_num - 1 ) * IMAGES_PER_PAGE
62+ end_index = start_index + IMAGES_PER_PAGE
63+
64+ # Handle clickData for a single point
65+ if clickData :
66+ items = []
67+ image_url = clickData ['points' ][0 ]['customdata' ][0 ]
68+ image_path = image_url .replace ('./assets/' , '' )
69+ label = os .path .basename (os .path .dirname (image_url ))
70+ items .append (
71+ html .Div ([
72+ html .Img (src = app .get_asset_url (image_path ), style = {'height' : '150px' , 'margin' : '5px' }),
73+ html .P (f"Class: { label } " , style = {'text-align' : 'center' })
74+ ], style = {'display' : 'inline-block' , 'margin' : '10px' })
75+ )
76+
77+ # Handle selectedData for multiple points
78+ if selectedData :
79+ items = []
80+ for point in selectedData ['points' ][start_index :end_index ]:
81+ image_url = point ['customdata' ][0 ]
82+ image_path = image_url .replace ('./assets/' , '' )
83+ label = os .path .basename (os .path .dirname (image_url ))
84+ items .append (
85+ html .Div ([
86+ html .Img (src = app .get_asset_url (image_path ), style = {'height' : '150px' , 'margin' : '5px' }),
87+ html .P (f"Class: { label } " , style = {'text-align' : 'center' })
88+ ], style = {'display' : 'inline-block' , 'margin' : '10px' })
89+ )
90+
91+ # Handle the case when no points are selected or clicked
92+ if not items :
93+ items = [html .Div ([
94+ html .Img (src = 'https://placedog.net/640/224?random' ),
95+ html .P ('No Points Selected. A wild Doge appears!' )
96+ ])]
97+
98+ return items
99+
100+ # Pagination
101+ @app .callback (
102+ [Output ('page-num-display' , 'children' ),
103+ Output ('hidden-page-num' , 'children' )],
104+ [Input ('increment-button' , 'n_clicks' ),
105+ Input ('decrement-button' , 'n_clicks' ),
106+ Input ('scatter-plot' , 'relayoutData' )],
107+ [State ('hidden-page-num' , 'children' ),
108+ State ('scatter-plot' , 'selectedData' )]
46109)
47- def display_image_and_label (clickData ):
48- if clickData is None :
49- return 'https://placedog.net/640/224?random' , 'Wild Doge appears!'
50- # Get the index of the clicked point
51- image_url = clickData ['points' ][0 ]['customdata' ][0 ]
52- # Get the corresponding image path and label
53- image_path = image_url .replace ('./assets/' , '' )
54- label = os .path .basename (os .path .dirname (image_url ))
55- return app .get_asset_url (image_path ), f"Class: { label } "
110+ def update_page_num (increment_clicks , decrement_clicks , relayoutData , page_num , selectedData ):
111+ page_num = int (page_num )
112+ ctx = callback_context
113+
114+ if not ctx .triggered :
115+ return f'{ page_num } ' , f'{ page_num } '
116+ else :
117+ button_id = ctx .triggered [0 ]['prop_id' ].split ('.' )[0 ]
118+
119+ # Check if double-click occurred (reset page number to 1)
120+ if relayoutData and 'xaxis.range' in relayoutData and 'yaxis.range' in relayoutData :
121+ if 'autosize' in relayoutData :
122+ page_num = INITIAL_PAGE_NUM
123+ else :
124+ page_num = int (page_num )
125+
126+ # Calculate the total number of pages
127+ total_items = len (selectedData ['points' ]) if selectedData else 0
128+ total_pages = math .ceil (total_items / IMAGES_PER_PAGE )
129+
130+ if button_id == 'increment-button' and increment_clicks and page_num < total_pages :
131+ page_num += 1
132+ elif button_id == 'decrement-button' and decrement_clicks and page_num > 1 :
133+ page_num -= 1
134+ else :
135+ page_num = 1
136+
137+ return f'{ page_num } ' , f'{ page_num } '
138+
139+ # Button Disable
140+ @app .callback (
141+ [Output ('decrement-button' , 'disabled' ),
142+ Output ('increment-button' , 'disabled' )],
143+ [Input ('hidden-page-num' , 'children' ),
144+ Input ('scatter-plot' , 'selectedData' )],
145+ [State ('scatter-plot' , 'selectedData' )]
146+ )
147+ def update_button_disabled (page_num , selectedData , stateSelectedData ):
148+ selectedData = selectedData or stateSelectedData
149+ total_items = len (selectedData ['points' ]) if selectedData else 0
150+ total_pages = math .ceil (total_items / IMAGES_PER_PAGE )
151+ page_num = int (page_num )
152+
153+ if total_items <= 12 :
154+ return True , True
155+
156+ decrement_disabled = page_num <= 1
157+ increment_disabled = page_num >= total_pages
158+
159+ return decrement_disabled , increment_disabled
56160
57161if __name__ == '__main__' :
58- app .run_server (debug = True )
162+ app .run_server (debug = True )
0 commit comments