summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSébastien Labbé <slabqc@gmail.com>2018-04-22 16:16:21 +0200
committerSébastien Labbé <slabqc@gmail.com>2018-05-24 10:12:37 +0200
commit004126a706f01ceacecc01a0c4db5483a8202d7e (patch)
treeb571c1a0994362d7315365e05a936ad615fab189
parent25125: improved sentences before some doctests (diff)
rewriting number_of_solutions method like one_solution and all_solutions
-rw-r--r--src/sage/combinat/matrices/dancing_links.pyx100
1 files changed, 40 insertions, 60 deletions
diff --git a/src/sage/combinat/matrices/dancing_links.pyx b/src/sage/combinat/matrices/dancing_links.pyx
index d981994..9fdb50e 100644
--- a/src/sage/combinat/matrices/dancing_links.pyx
+++ b/src/sage/combinat/matrices/dancing_links.pyx
@@ -667,59 +667,6 @@ cdef class dancing_linksWrapper:
L.extend(val)
return L
- def _number_of_solutions_iterator(self, ncpus=None, column=None):
- r"""
- Return an iterator over the number of solutions using each row
- containing a ``1`` in the given ``column``.
-
- INPUT:
-
- - ``ncpus`` -- integer (default: ``None``), maximal number of
- subprocesses to use at the same time. If ``None``, it detects the
- number of effective CPUs in the system using
- :func:`sage.parallel.ncpus.ncpus()`.
- - ``column`` -- integer (default: ``None``), the column used to split
- the problem, if ``None`` a random column is chosen
-
- OUTPUT:
-
- iterator of tuples (row number, number of solutions)
-
- EXAMPLES::
-
- sage: from sage.combinat.matrices.dancing_links import dlx_solver
- sage: rows = [[0,1,2], [3,4,5], [0,1], [2,3,4,5], [0], [1,2,3,4,5]]
- sage: d = dlx_solver(rows)
- sage: sorted(d._number_of_solutions_iterator(ncpus=2, column=3))
- [(1, 1), (3, 1), (5, 1)]
-
- ::
-
- sage: S = Subsets(range(5))
- sage: rows = [list(x) for x in S]
- sage: d = dlx_solver(rows)
- sage: d.number_of_solutions()
- 52
- sage: sum(b for a,b in d._number_of_solutions_iterator(ncpus=2, column=3))
- 52
- """
- if column is None:
- from random import randrange
- column = randrange(self.ncols())
-
- if not 0 <= column < self.ncols():
- raise ValueError("column(={}) must be in range(ncols) "
- "where ncols={}".format(column, self.ncols()))
-
- from sage.parallel.decorate import parallel
- @parallel(ncpus=ncpus)
- def nb_sol(i):
- return self.restrict([i]).number_of_solutions()
-
- indices = [i for (i,row) in enumerate(self._rows) if column in row]
- for ((args, kwds), val) in nb_sol(indices):
- yield args[0], val
-
def number_of_solutions(self, ncpus=None, column=None):
r"""
Return the number of distinct solutions.
@@ -758,19 +705,32 @@ cdef class dancing_linksWrapper:
sage: x.number_of_solutions(ncpus=2, column=3)
3
+ ::
+
+ sage: S = Subsets(range(5))
+ sage: rows = map(list, S)
+ sage: d = dlx_solver(rows)
+ sage: d.number_of_solutions()
+ 52
+
+ TESTS:
+
The way it is coded, solutions of a dlx solver can be iterated
through only once. The second call to the function gives wrong
result::
+ sage: rows = [[0,1,2], [3,4,5], [0,1], [2,3,4,5], [0], [1,2,3,4,5]]
sage: x = dlx_solver(rows)
- sage: x.number_of_solutions()
+ sage: x.number_of_solutions(ncpus=1)
3
- sage: x.number_of_solutions()
+ sage: x.number_of_solutions(ncpus=1)
0
- TESTS::
+ ::
- sage: dlx_solver([]).number_of_solutions()
+ sage: dlx_solver([]).number_of_solutions(ncpus=None)
+ 0
+ sage: dlx_solver([]).number_of_solutions(ncpus=1)
0
"""
cdef int N = 0
@@ -778,9 +738,29 @@ cdef class dancing_linksWrapper:
while self.search():
N += 1
return N
- else:
- it = self._number_of_solutions_iterator(ncpus, column)
- return sum(val for (k,val) in it)
+
+ if self.ncols() == 0:
+ return 0
+
+ if column is None:
+ from random import randrange
+ column = randrange(self.ncols())
+
+ if not 0 <= column < self.ncols():
+ raise ValueError("column(={}) must be in range(ncols) "
+ "where ncols={}".format(column, self.ncols()))
+
+ from sage.parallel.decorate import parallel
+ @parallel(ncpus=ncpus)
+ def nb_sol(i):
+ dlx = self.restrict([i])
+ N = 0
+ while dlx.search():
+ N += 1
+ return N
+
+ indices = [i for (i,row) in enumerate(self._rows) if column in row]
+ return sum(val for ((args, kwds), val) in nb_sol(indices))
def dlx_solver(rows):
"""