# sudoku.py -- Sudoku solver: naive backtracking, forward checking, and
# forward checking guided by variable/value-ordering heuristics.
# Usage: python sudoku.py <board-file> <algorithm>   (0 = naive, 1 = forward, 2 = heuristic)
import copy
import sys
from contextlib import suppress
from heapq import heappop, heappush
  5. board = sys.argv[1]
  6. algotype = sys.argv[2]
  7. # import board
  8. with open(board) as file:
  9. board = file.read().splitlines()
  10. board = board[:-1]
  11. # convert to list of list of ints
  12. for l in board:
  13. board[board.index(l)] = list(map(lambda x: int(x), l.split()))
  14. # return a board that is like the board b, but has domains for each element of b (always 1-9)
  15. def genDomains(b):
  16. for row in range(0, 9):
  17. for cell in range(0, 9):
  18. if (b[row][cell] == 0):
  19. b[row][cell] = list(range(1, 10))
  20. return b
  21. # returns True if value is valid
  22. def valid(brd, row, col, val):
  23. # check row
  24. if (val in brd[row]):
  25. return False
  26. # check column
  27. for i in range(0, 9):
  28. if (brd[i][col] == val):
  29. return False
  30. # check "box"
  31. rownum = int(row / 3)
  32. colnum = int(col / 3)
  33. for i in range(rownum * 3, rownum * 3 + 3):
  34. for j in range(colnum * 3, colnum * 3 + 3):
  35. if (brd[i][j] == val):
  36. return False
  37. return True
  38. # naive backtracking solver
  39. def naive(start):
  40. working = copy.deepcopy(start) # this is only "filled in values" and 0s
  41. solution = genDomains(start)
  42. # unassigned will be a list of positions we have to fill
  43. unassigned = []
  44. for i in range(0, 9):
  45. for j in range(0, 9):
  46. if (isinstance(solution[i][j], list)):
  47. unassigned.append((i, j))
  48. assumptions = []
  49. if(len(unassigned) == 0):
  50. return True
  51. # count assignments
  52. count = 0
  53. # while there are unassigned vars, keep going
  54. while(len(unassigned)):
  55. index = unassigned[-1]
  56. success = False
  57. # iterate over all values in the domain list
  58. while solution[index[0]][index[1]]:
  59. i = solution[index[0]][index[1]].pop()
  60. count += 1
  61. # took too long
  62. if (count >= 10,000):
  63. print("took too long")
  64. return False
  65. # check if this part of the domain(solution) is valid
  66. if (valid(working, index[0], index[1], i)):
  67. solution[index[0]][index[1]].append(i) # keep in the domain
  68. working[index[0]][index[1]] = i
  69. assumptions.append(index)
  70. unassigned.pop()
  71. success = True
  72. break
  73. if (success):
  74. continue
  75. else:
  76. # restore domain to full since we failed
  77. solution[index[0]][index[1]] = list(range(1, 10))
  78. working[index[0]][index[1]] = 0
  79. lastdex = assumptions.pop()
  80. solution[lastdex[0]][lastdex[1]].remove(working[lastdex[0]][lastdex[1]])
  81. working[lastdex[0]][lastdex[1]] = 0
  82. unassigned.append(lastdex)
  83. # if we exit without assigning everything, we should have failed
  84. if (unassigned): return False
  85. return working
  86. # returns a board (domains) where inferences are made for the cell at row, col
  87. def infer(solutions, brd, row, col, val):
  88. domains = copy.deepcopy(solutions)
  89. # remove from same row & col
  90. for i in range(0, 9):
  91. if (val in domains[row][i] and i != col):
  92. domains[row][i].remove(val)
  93. if (val in domains[i][col] and i != row):
  94. domains[i][col].remove(val)
  95. # remove for "box"
  96. rownum = int(row / 3)
  97. colnum = int(col / 3)
  98. for i in range(rownum * 3, rownum * 3 + 3):
  99. for j in range(colnum * 3, colnum * 3 + 3):
  100. if (val in domains[i][j] and (i != row and j != col)):
  101. domains[i][j].remove(val)
  102. return domains
  103. # generates domains in a format supporting forward checking
  104. def gen2Domains(b):
  105. for row in range(0, 9):
  106. for cell in range(0, 9):
  107. if (b[row][cell] == 0):
  108. b[row][cell] = list(range(1, 10))
  109. else:
  110. b[row][cell] = [b[row][cell]]
  111. return b
  112. # recursive solver for forward-checking
  113. def solve(working, domains, unassigned, count):
  114. if (not unassigned):
  115. return working
  116. index = unassigned.pop()
  117. # for every value in the domain, check if using it works. if all fail, backtrack.
  118. for i in domains[index[0]][index[1]]:
  119. working[index[0]][index[1]] = i
  120. count += 1
  121. # took too long
  122. if (count >= 10,000):
  123. print("took too long")
  124. return False
  125. newdomains = infer(domains, working, index[0], index[1], i)
  126. result = solve(working, newdomains, copy.deepcopy(unassigned), count)
  127. if (result):
  128. return result
  129. else:
  130. continue
  131. return False
  132. # forward checking solver
  133. def forward(start):
  134. working = copy.deepcopy(start) # this is only "filled in values" and 0s
  135. domains = gen2Domains(start)
  136. # unassigned will be a list of positions we have to fill
  137. unassigned = []
  138. for i in range(0, 9):
  139. for j in range(0, 9):
  140. if (len(domains[i][j]) == 9):
  141. unassigned.append((i, j))
  142. # forward-checking on pre-assigned values
  143. for i in range(0, 9):
  144. for j in range(0, 9):
  145. if (working[i][j] != 0):
  146. domains = infer(domains, working, i, j, working[i][j])
  147. return solve(working, domains, unassigned, 0)
  148. # returns size of domain for a given index
  149. def domsize(domains, index):
  150. return (len(domains[index[0]][index[1]]))
  151. # returns the # of 0s that are in the same row, col, or box as index
  152. def related(brd, index):
  153. count = 0
  154. # count 0s in row + col
  155. for i in range(0, 9):
  156. if (brd[index[0]][i] == 0 and i != index[1]):
  157. ++count
  158. if (brd[i][index[1]] == 0 and i != index[0]):
  159. ++count
  160. # count for "box" as well
  161. rownum = int(index[0] / 3)
  162. colnum = int(index[1] / 3)
  163. for i in range(rownum * 3, rownum * 3 + 3):
  164. for j in range(colnum * 3, colnum * 3 + 3):
  165. if (brd[i][j] == 0 and (i != index[0] and j != index[1])):
  166. ++count
  167. return count
  168. # returns the # of constraints that will follow from assigning index with val
  169. def lcv(solutions, index, val):
  170. count = 0
  171. # count 0s in row + col
  172. for i in range(0, 9):
  173. if (val in solutions[index[0]][i] and i != index[1]):
  174. ++count
  175. if (val in solutions[i][index[1]] and i != index[0]):
  176. ++count
  177. # count for "box" as well
  178. rownum = int(index[0] / 3)
  179. colnum = int(index[1] / 3)
  180. for i in range(rownum * 3, rownum * 3 + 3):
  181. for j in range(colnum * 3, colnum * 3 + 3):
  182. if (val in solutions[i][j] and (i != index[0] and j != index[1])):
  183. ++count
  184. return count
  185. # return the correct node + val to try
  186. def genVal(domains, working, unassigned):
  187. # used to track intermediary values
  188. heap = []
  189. superheap = []
  190. bestrating = 1.0
  191. # get the best indices according to domain size
  192. for i in unassigned:
  193. rating = domsize(domains, i) / 9.0
  194. if (rating < bestrating):
  195. bestrating = rating
  196. heap = [i]
  197. elif (rating == bestrating):
  198. heap.append(i)
  199. # get the best indices according to degree(related cells)
  200. bestrating = 1
  201. for i in heap:
  202. rating = related(working, i) / 27.0
  203. if (rating < bestrating):
  204. bestrating = rating
  205. superheap = [i]
  206. elif (rating == bestrating):
  207. superheap.append(i)
  208. index = superheap[0]
  209. bestrating = 27
  210. val = working[index[0]][index[1]]
  211. # get best values according to LCV
  212. for i in domains[index[0]][index[1]]:
  213. rating = lcv(domains, index, i)
  214. if (rating <= bestrating):
  215. bestrating = rating
  216. val = i
  217. return (index, val)
  218. # recursive solver that uses heuristics to decide what node to explore
  219. def solveh(working, domains, unassigned, count):
  220. if (not unassigned):
  221. return working
  222. # while there are unassigned values keep trying
  223. while(unassigned):
  224. # get next value using heuristics, remove this node from assigned
  225. nextThing = genVal(domains, working, unassigned)
  226. index = nextThing[0]
  227. val = nextThing[1]
  228. working[index[0]][index[1]] = val
  229. unassigned.remove(index)
  230. count += 1
  231. # took too long
  232. if (count >= 10,000):
  233. print("took too long")
  234. return False
  235. # check for invalidated nodes (empty domain)
  236. flag = True
  237. result = False
  238. newdomains = infer(domains, working, index[0], index[1], val)
  239. for i in range(0, 9):
  240. for j in range(0, 9):
  241. if (not domains[i][j]):
  242. flag = False
  243. # success! recurse
  244. if (flag): result = solveh(working, newdomains, copy.deepcopy(unassigned), count)
  245. if (result):
  246. return result
  247. elif (len(domains[index[0]][index[1]]) > 1): # remove from domain, keep going
  248. working[index[0]][index[1]] = 0
  249. domains[index[0]][index[1]].remove(val)
  250. unassigned.append(index)
  251. else: # no values worked :( return false
  252. return False
  253. # forward checking solver with heuristics
  254. def heuristic(start):
  255. working = copy.deepcopy(start) # this is only "filled in values" and 0s
  256. domains = gen2Domains(start)
  257. # unassigned will be a list of positions we have to fill
  258. unassigned = []
  259. for i in range(0, 9):
  260. for j in range(0, 9):
  261. if (len(domains[i][j]) == 9):
  262. unassigned.append((i, j))
  263. # initial inferences
  264. for i in range(0, 9):
  265. for j in range(0, 9):
  266. if (working[i][j] != 0):
  267. domains = infer(domains, working, i, j, working[i][j])
  268. return solveh(working, domains, unassigned, 0)
  269. def main():
  270. print("###########")
  271. print(*board, sep='\n')
  272. print("##########")
  273. if (algotype == str(0)):
  274. result = naive(board)
  275. elif (algotype == str(1)):
  276. result = forward(board)
  277. elif (algotype == str(2)):
  278. result = heuristic(board)
  279. else:
  280. print("No valid algorithm selected. RIP.")
  281. print("###########")
  282. print(*result, sep='\n')
  283. print("##########")
  284. main()