Skip to content

Commit 826a16c

Browse files
Atom strategy (#385)
* extra params for atom strategy * CHANGELOG.md * use var * tidying code
1 parent 2ce61fb commit 826a16c

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
77
## [Unreleased]
88

99
### Added
10-
- the `AtomStrategy` handles extra parameters when counting
10+
- the `AtomStrategy` handles extra parameters when counting, finding equations and sampling
1111

1212
### Changed
1313
- `from_dict` methods use `copy` to avoid deleting dicts

comb_spec_searcher/strategies/strategy.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -810,33 +810,38 @@ def get_terms(self, comb_class: CombinatorialClass, n: int) -> Terms:
810810
return Counter()
811811

812812
def get_objects(self, comb_class: CombinatorialClass, n: int) -> Objects:
813-
if comb_class.extra_parameters:
814-
raise NotImplementedError
815813
res: Objects = defaultdict(list)
816814
if n == comb_class.minimum_size_of_object():
817-
res[tuple()].append(next(comb_class.objects_of_size(n)))
815+
obj = next(comb_class.objects_of_size(n))
816+
param = comb_class.get_parameters(obj)
817+
res[param].append(obj)
818818
return res
819819

820820
def get_genf(
821821
self,
822822
comb_class: CombinatorialClass,
823823
funcs: Optional[Dict[CombinatorialClass, Function]] = None,
824824
) -> Any:
825-
if comb_class.extra_parameters:
826-
raise NotImplementedError
827825
if not self.verified(comb_class):
828826
raise StrategyDoesNotApply("Can't find generating functon for non-atom.")
829-
x = var("x")
830-
return x ** comb_class.minimum_size_of_object()
827+
obj = next(comb_class.objects_of_size(comb_class.minimum_size_of_object()))
828+
param = comb_class.get_parameters(obj)
829+
variables = comb_class.extra_parameters
830+
res = var("x") ** comb_class.minimum_size_of_object()
831+
for k, v in zip(variables, param):
832+
res *= var(k) ** v
833+
return res
831834

832835
def random_sample_object_of_size(
833836
self, comb_class: CombinatorialClass, n: int, **parameters: int
834837
) -> CombinatorialObject:
835-
if comb_class.extra_parameters:
836-
raise NotImplementedError
837838
if n != comb_class.minimum_size_of_object():
838839
raise ValueError("Invalid size")
839840
obj: CombinatorialObject = next(comb_class.objects_of_size(n))
841+
param = comb_class.get_parameters(obj)
842+
variables = comb_class.extra_parameters
843+
if parameters != dict(zip(variables, param)):
844+
raise ValueError("Invalid params")
840845
return obj
841846

842847
def verified(self, comb_class: CombinatorialClass) -> bool:

0 commit comments

Comments
 (0)