Skip to content

Commit 54f3eb7

Browse files
committed
refactor: simplify context creation and improve division logic in ndBinarySearch
1 parent 4878c3e commit 54f3eb7

File tree

1 file changed

+83
-125
lines changed

1 file changed

+83
-125
lines changed

src/nd.ts

Lines changed: 83 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -14,171 +14,126 @@ export type ShouldContinue<D extends number, T> = FixedLengthArray<
1414
D
1515
>;
1616

17-
const createContext = <
18-
D extends number,
19-
T,
20-
P extends (vector: NoInfer<Vector<D, T>>) => boolean,
21-
>(
22-
predicate: P,
23-
midpoint: FixedLengthArray<(always: T, never: T) => T, D>,
24-
shouldContinue: FixedLengthArray<(always: T, never: T) => boolean, D>,
25-
) => {
26-
/** Active components */
27-
const d = new Set(midpoint.keys());
28-
29-
const _continue = (
30-
always: ReadonlyVector<D, T>,
31-
never: ReadonlyVector<D, T>,
32-
) => {
33-
let result = 0;
34-
for (const i of d) {
35-
// biome-ignore lint/style/noNonNullAssertion: i is always valid index
36-
if (shouldContinue[i]!(always[i]!, never[i]!)) {
37-
result++;
38-
} else {
39-
// Deactivate component i if shouldContinue[i] is false (enough converged against component i)
40-
d.delete(i);
17+
export const createShouldContinue =
18+
<D extends number, T>(shouldContinue: ShouldContinue<D, T>) =>
19+
(division: Division<D, T>, components: Set<D>) => {
20+
const { always, never } = division;
21+
const result = new Set<D>();
22+
for (const i of components) {
23+
// Activate component i if shouldContinue[i] returns true (false means component i has converged sufficiently)
24+
if (shouldContinue[i](always[i], never[i])) {
25+
result.add(i);
4126
}
4227
}
43-
// Continue while active components exist
4428
return result;
4529
};
4630

47-
const _midpoint = (
48-
always: ReadonlyVector<D, T>,
49-
never: ReadonlyVector<D, T>,
50-
) =>
51-
midpoint.map((fn, i) =>
31+
const createMidpoint =
32+
<D extends number, T>(midpoint: Midpoint<D, T>) =>
33+
(division: Division<D, T>, components: Set<D>) => {
34+
const { always, never } = division;
35+
return midpoint.map((fn, i) =>
5236
// Apply the midpoint function only to active components
53-
// biome-ignore lint/style/noNonNullAssertion: i is always valid index
54-
d.has(i) ? fn(always[i]!, never[i]!) : always[i]!,
37+
components.has(i as D) ? fn(always[i], never[i]) : always[i],
5538
) as unknown as Vector<D, T>;
56-
57-
return {
58-
/** predicate */
59-
p: predicate,
60-
/** continue */
61-
c: _continue,
62-
/** midpoint */
63-
m: _midpoint,
64-
/**
65-
* array of active components.
66-
* `c: continue` mutates activation.
67-
* must be reset after recursive calls.
68-
*/
69-
d,
70-
} as const;
71-
};
72-
73-
type Context<
74-
D extends number,
75-
T,
76-
P extends (vector: ReadonlyVector<D, T>) => boolean,
77-
> = ReturnType<typeof createContext<D, T, P>>;
39+
};
7840

7941
type Division<D extends number, T> = {
8042
readonly always: ReadonlyVector<D, T>;
8143
readonly never: ReadonlyVector<D, T>;
8244
};
8345

84-
/** multiple Array.prototype.with */
85-
const vectorMergePartial = <D extends number, T>(
46+
/** Array.prototype.with */
47+
const vectorWith = <D extends number, T>(
8648
vector1: ReadonlyVector<D, T>,
87-
vector2: ReadonlyVector<D, T>,
88-
components: Iterable<number>,
49+
index: D,
50+
value: T,
8951
): Vector<D, T> => {
9052
const result = vector1.slice() as unknown as Vector<D, T>;
91-
for (const i of components) {
92-
// biome-ignore lint/style/noNonNullAssertion: index is always valid index
93-
result[i] = vector2[i]!;
94-
}
53+
result[index] = value;
9554
return result;
9655
};
9756

98-
const createDfsBinarySearch = <
99-
D extends number,
100-
T,
101-
P extends (point: ReadonlyVector<D, T>) => boolean,
102-
>(
103-
ctx: Context<D, T, P>,
104-
) => {
105-
const { p, m, c, d } = ctx;
106-
57+
const createDivide = <D extends number, T>(predicate: Predicate<D, T>) => {
10758
const divide = function* (
10859
forward: ReadonlyVector<D, T>,
10960
backward: ReadonlyVector<D, T>,
110-
base: ReadonlyVector<D, T>,
111-
baseResult: boolean,
112-
done = -1,
113-
omit = new Set<number>(),
61+
mid: ReadonlyVector<D, T>,
62+
result: boolean,
63+
components: ReadonlySet<D>,
64+
done: ReadonlySet<D> = new Set(),
11465
): Generator<Division<D, T>> {
115-
if (omit.size === 0) {
116-
const division = baseResult
117-
? { always: base, never: forward }
118-
: { always: forward, never: base };
119-
yield division;
120-
}
121-
if (omit.size === d.size) return;
122-
// NOTE: Copy to array so that we avoid mutations by continuation checks
123-
for (const i of [...d].filter((i) => i > done)) {
124-
// omit transition i
125-
omit.add(i);
126-
127-
// check if result changes between "base" and "forward without transition i"
128-
const newForward = vectorMergePartial(forward, base, omit);
129-
const newBackward = vectorMergePartial(base, backward, omit);
130-
const newResult = p(newForward);
131-
132-
// All combinations without transition i won't include boundary
133-
if (newResult === baseResult) {
134-
omit.delete(i);
66+
// Division with no omit must include boundary
67+
yield result
68+
? { always: mid, never: forward }
69+
: { always: forward, never: mid };
70+
const _done = new Set<D>(done);
71+
// Track all divisions with some omitted components
72+
for (const i of components) {
73+
if (_done.has(i)) continue;
74+
_done.add(i); // [0]
75+
76+
// Check whether the result changes between "mid" and "omitted"
77+
const omitted = vectorWith(forward, i, mid[i]);
78+
if (predicate(omitted) === result) {
79+
// If not, skip this and subsequent divisions:
80+
// this combination of omitted components makes a boundary-including division impossible.
13581
continue;
13682
}
83+
const counter = vectorWith(mid, i, backward[i]);
13784

138-
const always = baseResult ? newBackward : newForward;
139-
const never = baseResult ? newForward : newBackward;
140-
141-
// This division includes boundary
142-
yield { always, never };
143-
144-
// Generate all combinations without transition i, which can include boundary
145-
yield* divide(forward, backward, base, baseResult, i, omit);
146-
147-
// All combinations without transition i are already generated, continue with transition i
148-
omit.delete(i);
85+
// If the result changes, yield this division.
86+
// Subsequent divisions may include the boundary.
87+
yield* divide(omitted, backward, counter, result, components, _done);
14988
}
15089
};
90+
return divide;
91+
};
15192

93+
const createDfsBinarySearch = <D extends number, T>(
94+
predicate: Predicate<D, T>,
95+
divide: ReturnType<typeof createDivide<D, T>>,
96+
midpoint: (
97+
division: Division<D, T>,
98+
activeComponents: Set<D>,
99+
) => Vector<D, T>,
100+
shouldContinue: (
101+
division: Division<D, T>,
102+
activeComponents: Set<D>,
103+
) => Set<D>,
104+
) => {
152105
const dfsBinarySearch = function* (
153106
division: Division<D, T>,
107+
activeComponents: Set<D>,
154108
): Generator<Vector<D, T>> {
155-
const mid = m(division.always, division.never);
156-
const result = p(mid);
109+
const components = shouldContinue(division, activeComponents);
110+
if (components.size === 0) {
111+
yield division.always;
112+
return;
113+
}
114+
115+
const mid = midpoint(division, components);
116+
const result = predicate(mid);
117+
157118
const forward = result ? division.never : division.always;
158119
const backward = result ? division.always : division.never;
159120

160-
for (const { always, never } of divide(forward, backward, mid, result)) {
161-
// Create a copy of the active components to restore later
162-
const _d = [...d];
163-
if (c(always, never)) {
164-
// More division needed
165-
yield* dfsBinarySearch({ always, never });
166-
} else {
167-
yield always;
168-
}
169-
// Restore active components deactivated by DFS
170-
_d.forEach((i) => d.add(i));
121+
const divisions = divide(forward, backward, mid, result, components);
122+
123+
for (const { always, never } of divisions) {
124+
yield* dfsBinarySearch({ always, never }, components);
171125
}
172126
};
127+
173128
return dfsBinarySearch;
174129
};
175130

176131
export const ndBinarySearch = function* <D extends number, T>(
177132
alwaysEnd: ReadonlyVector<D, T>,
178133
neverEnd: ReadonlyVector<D, T>,
179-
predicate: (vector: ReadonlyVector<D, T>) => boolean,
180-
midpoint: FixedLengthArray<(always: T, never: T) => T, D>,
181-
shouldContinue: FixedLengthArray<(always: T, never: T) => boolean, D>,
134+
predicate: Predicate<D, T>,
135+
midpoint: Midpoint<D, T>,
136+
shouldContinue: ShouldContinue<D, T>,
182137
) {
183138
if (
184139
alwaysEnd.length !== neverEnd.length ||
@@ -187,7 +142,10 @@ export const ndBinarySearch = function* <D extends number, T>(
187142
) {
188143
throw new Error("All input vectors must have the same length");
189144
}
190-
const ctx = createContext(predicate, midpoint, shouldContinue);
191-
const dfsBinarySearch = createDfsBinarySearch(ctx);
192-
yield* dfsBinarySearch({ always: alwaysEnd, never: neverEnd });
145+
const divide = createDivide(predicate);
146+
const m = createMidpoint(midpoint);
147+
const c = createShouldContinue(shouldContinue);
148+
const dfsBinarySearch = createDfsBinarySearch(predicate, divide, m, c);
149+
const components = new Set<D>(Array.from(alwaysEnd, (_, i) => i as D));
150+
yield* dfsBinarySearch({ always: alwaysEnd, never: neverEnd }, components);
193151
};

0 commit comments

Comments
 (0)