Skip to content

Commit 46a3520

Browse files
authored
feat: add association rendering function (#141)
* feat: add association rendering function * feat: update association chart rendering and utility constants * refactor: replace map with forEach in association chart rendering
1 parent 63e402a commit 46a3520

File tree

10 files changed

+205
-35
lines changed

10 files changed

+205
-35
lines changed

__tests__/charts/utils/paths.spec.ts

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,44 +134,53 @@ describe('paths', () => {
134134
const arrowheadWidth = 3;
135135

136136
it('should create an arrow generator function', () => {
137-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
137+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
138138
expect(typeof arrowGenerator).toBe('function');
139139
});
140140

141141
it('should generate correct SVG path for horizontal arrow', () => {
142-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
142+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
143143
const startData = { index: 0, value: 10 };
144144
const endData = { index: 2, value: 10 };
145145
const result = arrowGenerator(startData, endData);
146146

147-
// Should contain move to start, line to base, and arrowhead
148-
expect(result).toContain('M5 50'); // Start point
149-
expect(result).toContain('L20 50'); // Base point
150-
expect(result).toContain('M25 50'); // End point
147+
// Calculate expected coordinates:
148+
// startX = mockXScale(0) = 0, startY = mockYScale(10) = 50
149+
// endX = mockXScale(2) = 20, endY = mockYScale(10) = 50
150+
// baseX = endX - arrowheadLength * Math.cos(0) = 20 - 5 = 15
151+
expect(result).toContain('M0 50'); // Start point
152+
expect(result).toContain('L15 50'); // Line to base point
153+
expect(result).toContain('M20 50'); // End point
151154
});
152155

153156
it('should generate correct SVG path for vertical arrow', () => {
154-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
157+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
155158
const startData = { index: 0, value: 10 };
156159
const endData = { index: 0, value: 20 };
157160
const result = arrowGenerator(startData, endData);
158161

159-
expect(result).toContain('M5 50'); // Start point
160-
expect(result).toContain('M5 0'); // End point
162+
// Calculate expected coordinates:
163+
// startX = mockXScale(0) = 0, startY = mockYScale(10) = 50
164+
// endX = mockXScale(0) = 0, endY = mockYScale(20) = 100
165+
expect(result).toContain('M0 50'); // Start point
166+
expect(result).toContain('M0 100'); // End point
161167
});
162168

163169
it('should generate correct SVG path for diagonal arrow', () => {
164-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
170+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
165171
const startData = { index: 0, value: 10 };
166172
const endData = { index: 1, value: 20 };
167173
const result = arrowGenerator(startData, endData);
168174

169-
expect(result).toContain('M5 50'); // Start point
170-
expect(result).toContain('M15 0'); // End point
175+
// Calculate expected coordinates:
176+
// startX = mockXScale(0) = 0, startY = mockYScale(10) = 50
177+
// endX = mockXScale(1) = 10, endY = mockYScale(20) = 100
178+
expect(result).toContain('M0 50'); // Start point
179+
expect(result).toContain('M10 100'); // End point
171180
});
172181

173182
it('should use default arrowhead dimensions when not provided', () => {
174-
const arrowGenerator = arrow(mockXScale, mockYScale, height);
183+
const arrowGenerator = arrow(mockXScale, mockYScale);
175184
const startData = { index: 0, value: 10 };
176185
const endData = { index: 1, value: 10 };
177186
const result = arrowGenerator(startData, endData);
@@ -181,22 +190,27 @@ describe('paths', () => {
181190
});
182191

183192
it('should handle same start and end points', () => {
184-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
193+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
185194
const startData = { index: 1, value: 10 };
186195
const endData = { index: 1, value: 10 };
187196
const result = arrowGenerator(startData, endData);
188197

189-
expect(result).toContain('M15 50'); // Same start and end
198+
// Calculate expected coordinates:
199+
// startX = endX = mockXScale(1) = 10, startY = endY = mockYScale(10) = 50
200+
expect(result).toContain('M10 50'); // Same start and end
190201
});
191202

192203
it('should handle negative values', () => {
193-
const arrowGenerator = arrow(mockXScale, mockYScale, height, arrowheadLength, arrowheadWidth);
204+
const arrowGenerator = arrow(mockXScale, mockYScale, arrowheadLength, arrowheadWidth);
194205
const startData = { index: 0, value: -5 };
195206
const endData = { index: 1, value: 5 };
196207
const result = arrowGenerator(startData, endData);
197208

198-
expect(result).toContain('M5 125'); // Start point with negative value
199-
expect(result).toContain('M15 75'); // End point
209+
// Calculate expected coordinates:
210+
// startX = mockXScale(0) = 0, startY = mockYScale(-5) = -25
211+
// endX = mockXScale(1) = 10, endY = mockYScale(5) = 25
212+
expect(result).toContain('M0 -25'); // Start point with negative value
213+
expect(result).toContain('M10 25'); // End point
200214
});
201215
});
202216
});

example/main.tsx

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
renderSeasonalityChart,
1212
renderAnomalyChart,
1313
renderDistribution,
14+
renderAssociationChart,
1415
} from '../src/charts';
1516

1617
const dimensionValueDescriptor: SpecificEntityPhraseDescriptor = {
@@ -146,3 +147,11 @@ for (let i = 0; i < SAMPLE_SIZE * 0.2; i++) {
146147
renderChart(renderDistribution)({
147148
data: distributionData,
148149
});
150+
151+
renderChart(renderAssociationChart)({
152+
data: [
153+
{ x: -1, y: -2 },
154+
{ x: 2, y: 2 },
155+
{ x: 3, y: 20 },
156+
],
157+
});

src/charts/association/index.ts

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import {
2+
createSvg,
3+
extent,
4+
getElementFontSize,
5+
scaleLinear,
6+
SCALE_ADJUST,
7+
arrow,
8+
LINE_STROKE_COLOR,
9+
linearRegression,
10+
ARROW_FILL_COLOR,
11+
HIGHLIGHT_COLOR,
12+
} from '../utils';
13+
14+
interface Point {
15+
x: number;
16+
y: number;
17+
}
18+
19+
export interface AssociationChartConfig {
20+
data: Point[];
21+
}
22+
23+
export const renderAssociationChart = (container: Element, config: AssociationChartConfig): void => {
24+
const { data = [] } = config;
25+
if (!data.length) return;
26+
27+
if (data.length < 2) throw new Error('data must contain at least 2 points');
28+
29+
const chartSize = getElementFontSize(container);
30+
31+
const height = chartSize;
32+
const width = chartSize * 2;
33+
34+
const svg = createSvg(container, width, height);
35+
36+
const xValueExtent = extent(data.map((d) => d.x));
37+
const yValueExtent = extent(data.map((d) => d.y));
38+
39+
const xValueDomain: [number, number] = [
40+
xValueExtent[0] > 0 ? 0 : xValueExtent[0],
41+
xValueExtent[1] < 0 ? 0 : xValueExtent[1],
42+
];
43+
const yValueDomain: [number, number] = [
44+
yValueExtent[0] > 0 ? 0 : yValueExtent[0],
45+
yValueExtent[1] < 0 ? 0 : yValueExtent[1],
46+
];
47+
48+
const xScale = scaleLinear(xValueDomain, [SCALE_ADJUST, width - SCALE_ADJUST]);
49+
const yScale = scaleLinear(yValueDomain, [height - SCALE_ADJUST, SCALE_ADJUST]);
50+
51+
const zeroXPosition = xScale(0);
52+
const zeroYPosition = yScale(0);
53+
54+
const linearRegressionResult = linearRegression(data);
55+
56+
const tagData: Point[] = data.map((d) => {
57+
const tag = linearRegressionResult.k * d.x + linearRegressionResult.b;
58+
59+
return {
60+
x: d.x,
61+
y: tag,
62+
};
63+
});
64+
65+
svg
66+
.append('line')
67+
.attr('x1', zeroXPosition)
68+
.attr('y1', 0)
69+
.attr('x2', zeroXPosition)
70+
.attr('y2', height)
71+
.attr('stroke', LINE_STROKE_COLOR);
72+
73+
svg
74+
.append('line')
75+
.attr('x1', 0)
76+
.attr('y1', zeroYPosition)
77+
.attr('x2', width)
78+
.attr('y2', zeroYPosition)
79+
.attr('stroke', LINE_STROKE_COLOR);
80+
81+
const arrowPath = arrow(xScale, yScale);
82+
83+
svg
84+
.append('path')
85+
.attr(
86+
'd',
87+
arrowPath(
88+
{ index: tagData[0].x, value: tagData[0].y },
89+
{ index: tagData[tagData.length - 1].x, value: tagData[tagData.length - 1].y },
90+
),
91+
)
92+
.attr('stroke', ARROW_FILL_COLOR)
93+
.attr('fill', ARROW_FILL_COLOR);
94+
95+
data.forEach((d) => {
96+
svg.append('circle').attr('cx', xScale(d.x)).attr('cy', yScale(d.y)).attr('r', 1).attr('fill', HIGHLIGHT_COLOR);
97+
});
98+
};

src/charts/difference/index.tsx

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import { renderRankChart } from '../rank';
2-
import { arrow, getElementFontSize } from '../utils';
3-
4-
const ARROW_FILL_COLOR = 'rgb(250, 84, 28)';
2+
import { ARROW_FILL_COLOR, arrow } from '../utils';
53

64
export interface DifferenceChartConfig {
75
data: number[];
@@ -11,15 +9,15 @@ export const renderDifferenceChart = (container: Element, config: DifferenceChar
119
const { data = [] } = config;
1210
if (!data.length) return;
1311

14-
const chartSize = getElementFontSize(container);
15-
1612
renderRankChart(container, { data }, (svg, xScale, yScale) => {
1713
// draw arrow on rank chart
18-
const height = chartSize;
19-
const arrowPath = arrow(xScale, yScale, height);
14+
const arrowPath = arrow(xScale, yScale);
2015
svg
2116
.append('path')
22-
.attr('d', arrowPath({ index: 0, value: data[0] }, { index: data.length - 1, value: data[data.length - 1] }))
17+
.attr(
18+
'd',
19+
arrowPath({ index: 0 + 0.5, value: data[0] }, { index: data.length - 1 + 0.5, value: data[data.length - 1] }),
20+
)
2321
.attr('stroke', ARROW_FILL_COLOR)
2422
.attr('fill', ARROW_FILL_COLOR);
2523
});

src/charts/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ export { renderDifferenceChart, type DifferenceChartConfig } from './difference'
55
export { renderSeasonalityChart, type SeasonalityChartConfig } from './seasonality';
66
export { renderAnomalyChart, type AnomalyChartConfig } from './anomaly';
77
export { renderDistribution, type DistributionConfig } from './distribution';
8+
export { renderAssociationChart, type AssociationChartConfig } from './association';

src/charts/rank/index.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ export const renderRankChart = (
6464
const xScale = scaleLinear([0, data.length - 1], [0, width - barWidth]);
6565
const maxValue = Math.max(...data);
6666
const minValue = Math.min(...data);
67-
const yScale = scaleLinear([minValue, maxValue], [SCALE_ADJUST, height - SCALE_ADJUST]);
67+
const yScale = scaleLinear([minValue, maxValue], [height - SCALE_ADJUST, SCALE_ADJUST]);
6868

6969
// Draw bars
7070
data.forEach((value, index) => {
@@ -75,9 +75,9 @@ export const renderRankChart = (
7575
.append('rect')
7676
.attr('class', 'bar')
7777
.attr('x', x)
78-
.attr('y', height - y)
78+
.attr('y', y)
7979
.attr('width', barWidth)
80-
.attr('height', y)
80+
.attr('height', height - y)
8181
.attr('fill', BAR_FILL_COLOR)
8282
.style('cursor', 'pointer');
8383
});

src/charts/utils/const.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ export const SCALE_ADJUST = 2;
22
export const WIDTH_MARGIN = 0.5;
33
export const LINE_STROKE_COLOR = '#5B8FF9';
44
export const HIGHLIGHT_COLOR = '#FF8C00';
5+
export const ARROW_FILL_COLOR = 'rgb(250, 84, 28)';
56
export const OPACITY = 0.6;

src/charts/utils/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ export { line, area, arc, arrow, createCurvePath } from './paths';
55
export { getElementFontSize, DEFAULT_FONT_SIZE } from './getElementFontSize';
66
export { createSvg } from './createSvg';
77
export { getSafeDomain } from './getSafeDomain';
8-
export { SCALE_ADJUST, WIDTH_MARGIN, LINE_STROKE_COLOR, HIGHLIGHT_COLOR, OPACITY } from './const';
98
export { max, extent, mean } from './data';
9+
export { linearRegression } from './linearRegression';
10+
export * from './const';
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
export interface LinearRegressionResult {
2+
k: number;
3+
b: number;
4+
}
5+
6+
/**
7+
* use least square method to fit a line to the points
8+
* @param points points to fit
9+
* @returns {k, b} the slope and intercept of the line
10+
*/
11+
export function linearRegression(points: { x: number; y: number }[]): LinearRegressionResult {
12+
if (!points || points.length < 2) {
13+
throw new Error('Points array must contain at least two points for linear regression.');
14+
}
15+
16+
const n = points.length;
17+
let sumX = 0;
18+
let sumY = 0;
19+
20+
for (const p of points) {
21+
sumX += p.x;
22+
sumY += p.y;
23+
}
24+
25+
const meanX = sumX / n;
26+
const meanY = sumY / n;
27+
28+
let numerator = 0;
29+
let denominator = 0;
30+
31+
for (const p of points) {
32+
const diffX = p.x - meanX;
33+
const diffY = p.y - meanY;
34+
35+
numerator += diffX * diffY;
36+
denominator += diffX * diffX;
37+
}
38+
39+
if (denominator === 0) {
40+
return { k: 0, b: meanY };
41+
}
42+
43+
const k = numerator / denominator;
44+
45+
const b = meanY - k * meanX;
46+
47+
return { k, b };
48+
}

src/charts/utils/paths.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ export const arc = (radius: number) => {
7979
* @param yScale - Scale function for Y coordinates
8080
* @returns Function that generates SVG path data from start and end points
8181
*/
82-
export const arrow = (xScale: Scale, yScale: Scale, height: number, arrowheadLength = 2, arrowheadWidth = 2) => {
82+
export const arrow = (xScale: Scale, yScale: Scale, arrowheadLength = 2, arrowheadWidth = 2) => {
8383
return (startData: { index: number; value: number }, endData: { index: number; value: number }): string => {
84-
const startX = xScale(startData.index + 0.5);
85-
const startY = height - yScale(startData.value);
86-
const endX = xScale(endData.index + 0.5);
87-
const endY = height - yScale(endData.value);
84+
const startX = xScale(startData.index);
85+
const startY = yScale(startData.value);
86+
const endX = xScale(endData.index);
87+
const endY = yScale(endData.value);
8888

8989
const deltaX = endX - startX;
9090
const deltaY = endY - startY;

0 commit comments

Comments
 (0)