An array is squareful if the sum of every pair of adjacent elements is a perfect square.
Given an integer array nums, return the number of permutations of nums that are squareful.
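As a minimal sketch of the definition, the check below tests every adjacent pair. The helper names `is_perfect_square` and `is_squareful` are illustrative and not part of the problem. It uses `math.isqrt` (Python 3.8+), which returns the exact integer square root, so no floating-point rounding is involved.

```python
from math import isqrt
from typing import List


def is_perfect_square(s: int) -> bool:
    # isqrt(s) is floor(sqrt(s)) computed exactly on integers
    r = isqrt(s)
    return r * r == s


def is_squareful(arr: List[int]) -> bool:
    # Every adjacent pair (arr[i], arr[i + 1]) must sum to a perfect square.
    # A single-element array has no pairs, so it is trivially squareful.
    return all(is_perfect_square(a + b) for a, b in zip(arr, arr[1:]))


print(is_squareful([1, 8, 17]))  # True: 1+8=9, 8+17=25
print(is_squareful([1, 17, 8]))  # False: 1+17=18 is not a square
```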
Two permutations perm1 and perm2 are different if there is some index i such that perm1[i] != perm2[i].
Example 1:
Input: nums = [1,17,8] Output: 2 Explanation: [1,8,17] and [17,8,1] are the valid permutations.
Example 2:
Input: nums = [2,2,2] Output: 1
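The examples can be reproduced with a brute-force reference, sketched here on the assumption that it only runs on small inputs. It enumerates all orderings with `itertools.permutations`, deduplicates them with a set (so swapping equal values does not create a new permutation, matching the definition above), and keeps the squareful ones. With n = 12 there are 12! = 479,001,600 orderings, so this is useful as a test oracle rather than as a solution. The function name `brute_force_count` is hypothetical.

```python
from itertools import permutations
from math import isqrt
from typing import List


def is_perfect_square(s: int) -> bool:
    r = isqrt(s)
    return r * r == s


def brute_force_count(nums: List[int]) -> int:
    # A set of value tuples: orderings that differ only by swapping
    # equal values collapse into one entry.
    distinct = set(permutations(nums))
    # Count the orderings whose adjacent pairs all sum to squares
    return sum(
        all(is_perfect_square(a + b) for a, b in zip(p, p[1:]))
        for p in distinct
    )


print(brute_force_count([1, 17, 8]))  # 2
print(brute_force_count([2, 2, 2]))   # 1
```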
Constraints:
1 <= nums.length <= 12
0 <= nums[i] <= 10^9
This problem asks for the number of permutations of the input array nums such that the sum of every pair of adjacent elements is a perfect square. The solution uses dynamic programming over bitmasks (subsets of indices) to count the valid permutations without enumerating all n! orderings.
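The bit manipulation here means representing a subset of indices as an integer whose bit j is 1 when index j is in the subset. The helpers below are hypothetical names for the three operations the DP relies on: testing membership, removing an index, and building the mask that contains all n indices.

```python
from typing import List


def contains(mask: int, j: int) -> bool:
    # Bit j of mask is 1 exactly when index j is in the subset
    return ((mask >> j) & 1) == 1


def remove(mask: int, j: int) -> int:
    # XOR clears bit j (assumes j is currently in the subset)
    return mask ^ (1 << j)


def full_mask(n: int) -> int:
    # The n low bits set: every index 0..n-1 is included
    return (1 << n) - 1


def members(mask: int, n: int) -> List[int]:
    # List the indices whose bits are set
    return [j for j in range(n) if contains(mask, j)]


mask = 0b101                  # the subset {0, 2}
print(members(mask, 3))       # [0, 2]
print(bin(remove(mask, 2)))   # 0b1
print(bin(full_mask(3)))      # 0b111
```

Removing a bit always produces a numerically smaller mask, which is why iterating masks in increasing order guarantees that every sub-mask's entries are computed before they are read.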
Approach:
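The core idea is to build a dynamic programming table f[i][j], where:
- i is a bitmask indicating which elements of nums have been included in the current permutation. A bit set to 1 means the corresponding element is included.
- j is the index of the last element added to the permutation.
- f[i][j] stores the number of permutations that use exactly the elements in bitmask i and end with element j.
The base case is f[1 << j][j] = 1 for every j. For larger masks, f[i][j] is the sum of f[i ^ (1 << j)][k] over every other index k in i for which nums[j] + nums[k] is a perfect square: a permutation ending at j is a permutation of the remaining elements, ending at some compatible k, with j appended. The algorithm processes the bitmasks in increasing order and sums f[(1 << n) - 1][j] over all j to count the permutations that use every element. Because the DP distinguishes elements by index, equal values at different positions are treated as different elements, so each distinct permutation is counted once for every rearrangement of its duplicates. Dividing the total by c! for each value that occurs c times removes this overcounting.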
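To see why dividing by factorials is valid, the sketch below (hypothetical helper names, brute force, small inputs only) counts squareful orderings of indices, as the DP does, and then divides by c! for each value that appears c times. For [1, 1, 8, 8] there are 12 squareful index orderings but only 3 distinct value sequences, (1,8,1,8), (1,8,8,1) and (8,1,8,1), since each distinct sequence corresponds to 2! * 2! = 4 index orderings.

```python
from collections import Counter
from itertools import permutations
from math import factorial, isqrt
from typing import List, Sequence


def is_perfect_square(s: int) -> bool:
    r = isqrt(s)
    return r * r == s


def squareful(seq: Sequence[int]) -> bool:
    return all(is_perfect_square(a + b) for a, b in zip(seq, seq[1:]))


def count_by_index(nums: List[int]) -> int:
    # Treat equal values at different indices as different elements,
    # which is what the DP over index bitmasks does.
    n = len(nums)
    return sum(squareful([nums[i] for i in p]) for p in permutations(range(n)))


def count_distinct(nums: List[int]) -> int:
    # Each distinct sequence was counted prod(c!) times; divide it out.
    ans = count_by_index(nums)
    for c in Counter(nums).values():
        ans //= factorial(c)
    return ans


print(count_by_index([1, 1, 8, 8]))  # 12
print(count_distinct([1, 1, 8, 8]))  # 3
```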
Time Complexity Analysis:
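The outer loop iterates over all subsets (bitmasks) of nums, which is O(2^n), where n is the length of nums. For each bitmask, the nested loops over j and k do O(n^2) work, and each perfect-square check takes constant time. Therefore, the overall time complexity is O(n^2 * 2^n).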
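The loop structure can also be counted directly. The hypothetical helper below mirrors the three nested loops of the solution code and counts how many times the innermost k-loop body runs. Since each index j belongs to exactly half of all masks, the total is n * n * 2^(n-1), which is within an n^2 * 2^n bound; for the largest input, n = 12, that is 294,912 iterations.

```python
def inner_iterations(n: int) -> int:
    # For each mask i and each j in i, the k-loop runs over all n indices
    count = 0
    for i in range(1 << n):
        for j in range(n):
            if i >> j & 1:
                count += n
    return count


print(inner_iterations(12))  # 294912
```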
Space Complexity Analysis:
The space complexity is dominated by the DP table f, which has 2^n rows of n entries each. Therefore, the space complexity is O(n * 2^n).
Code Explanation (Python):
from collections import Counter
from math import factorial, isqrt
from typing import List


class Solution:
    def numSquarefulPerms(self, nums: List[int]) -> int:
        n = len(nums)
        # f[i][j]: orderings of the indices in bitmask i that end at index j
        f = [[0] * n for _ in range(1 << n)]
        for j in range(n):
            f[1 << j][j] = 1  # Base case: single-element permutations
        for i in range(1 << n):
            for j in range(n):
                if i >> j & 1:  # Element j is in the current subset
                    for k in range(n):
                        if (i >> k & 1) and k != j:  # k is another element of the subset
                            s = nums[j] + nums[k]
                            t = isqrt(s)  # Exact integer square root
                            if t * t == s:  # nums[k] and nums[j] may be adjacent
                                f[i][j] += f[i ^ (1 << j)][k]
        # Full mask: sum over every possible last element
        ans = sum(f[(1 << n) - 1][j] for j in range(n))
        # Equal values were counted as distinct indices: divide out c! per value
        for v in Counter(nums).values():
            ans //= factorial(v)
        return ans
The Java, C++, and Go versions follow the same structure and implement the same bitmask dynamic programming; they differ only in syntax and in the standard-library functions used.
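A variant worth considering, sketched here as an alternative rather than the original solution, precomputes for each index j the list adj[j] of indices k whose value can sit next to nums[j]. The perfect-square test then runs n * (n - 1) times in total instead of inside the 2^n loop, and the k-loop only visits compatible neighbours. The worst case (for example, all values equal) is still O(n^2 * 2^n). The function name `num_squareful_perms` and the `adj` table are illustrative assumptions.

```python
from collections import Counter
from math import factorial, isqrt
from typing import List


def num_squareful_perms(nums: List[int]) -> int:
    n = len(nums)

    def is_square(s: int) -> bool:
        r = isqrt(s)
        return r * r == s

    # adj[j]: every k != j whose value may sit next to nums[j]
    adj = [[k for k in range(n) if k != j and is_square(nums[j] + nums[k])]
           for j in range(n)]

    f = [[0] * n for _ in range(1 << n)]
    for j in range(n):
        f[1 << j][j] = 1  # Base case: single-element permutations
    for i in range(1 << n):
        for j in range(n):
            if i >> j & 1:
                prev = i ^ (1 << j)  # The subset before j was appended
                for k in adj[j]:
                    if prev >> k & 1:
                        f[i][j] += f[prev][k]

    ans = sum(f[(1 << n) - 1])
    # Divide out the rearrangements of equal values
    for c in Counter(nums).values():
        ans //= factorial(c)
    return ans


print(num_squareful_perms([1, 17, 8]))  # 2
print(num_squareful_perms([2, 2, 2]))   # 1
```

The results match the two examples; the only change from the solution above is where the perfect-square test happens.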