Split an array into k subarrays

Hint: binary search is the usual tool for this kind of question.

The simplest form of this question is finding the pivot index: an index in an array where the sum of all elements to its left equals the sum of all elements to its right.

For this question, use a list prefix to hold the left sums (prefix[i] is the sum of the elements before index i) and a variable suffix to hold the total sum. One pass over the list then finds a pivot i where prefix[i] = (suffix – array[i]) / 2.0.
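
As a minimal sketch of the pivot search described above: the function name pivot_index is my own, and instead of storing the whole prefix list it keeps a running left sum, which is enough for a single pass. The comparison is multiplied out to stay in integers.

```python
from typing import List


def pivot_index(nums: List[int]) -> int:
    """Return the first index whose left sum equals its right sum, or -1."""
    total = sum(nums)  # the post's `suffix`: sum of the whole list
    left = 0           # the post's prefix[i]: sum of nums[:i]
    for i, x in enumerate(nums):
        # left == total - left - x  <=>  left == (total - x) / 2
        if 2 * left == total - x:
            return i
        left += x
    return -1


print(pivot_index([1, 7, 3, 6, 5, 6]))  # 3
```

Here index 3 works because 1 + 7 + 3 = 11 on the left and 5 + 6 = 11 on the right.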

If we want to split an array into k subarrays with k > 2, in most cases we still work with the prefix and suffix sums of the array, together with the relationship among the subarrays. A common example is LeetCode 1712: Ways to Split Array Into Three Subarrays:

Q1: Count the number of ways to split a list nums of non-negative integers into three non-empty contiguous subarrays A1, A2, and A3, with sums S1, S2, and S3 respectively, such that S1 ≤ S2 ≤ S3. Return the value modulo (10⁹ + 7).

In this case, we compute the prefix sums of nums and form an inequality among the subarrays: given the prefix array (prefix[t] is the sum of the first t elements) and cut points i < j, prefix[i] ≤ prefix[j] – prefix[i] ≤ prefix[-1] – prefix[j].

The inequality can further be written in 2 separate forms as:

  • prefix[j] ≥ prefix[i] * 2 ……………………………………………… form 1
  • prefix[j] ≤ (prefix[i] + prefix[-1]) // 2 ……………………….. form 2
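
To make the equivalence concrete, here is a small sketch that, for one fixed i, lists the cut points j accepted by the chained inequality and by the two separate forms. The helper name j_candidates and the sample input are my own choices for illustration.

```python
from itertools import accumulate
from typing import List, Tuple


def j_candidates(nums: List[int], i: int) -> Tuple[List[int], List[int]]:
    """For a fixed first cut i, list the second cuts j accepted by each formulation."""
    prefix = [0] + list(accumulate(nums))  # prefix[t] = sum(nums[:t])
    total = prefix[-1]
    js = range(i + 1, len(nums))  # j > i keeps A2 non-empty, j < n keeps A3 non-empty
    chained = [j for j in js
               if prefix[i] <= prefix[j] - prefix[i] <= total - prefix[j]]
    two_forms = [j for j in js
                 if prefix[j] >= 2 * prefix[i]               # form 1
                 and prefix[j] <= (prefix[i] + total) // 2]  # form 2
    return chained, two_forms


print(j_candidates([1, 2, 2, 2, 5, 0], 1))  # ([2, 3], [2, 3])
```

Both lists agree, and because the prefix sums are integers, the floor division in form 2 does not drop any valid j.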

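We then iterate over i and use the functions in bisect to find the range of valid j: bisect_left gives the smallest j satisfying form 1, and bisect_right gives one past the largest j satisfying form 2. This relies on prefix being sorted, which holds because nums is non-negative.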

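from bisect import bisect_left, bisect_right
from typing import List

def waysToSplit(nums: List[int]) -> int:
    # prefix[t] = sum(nums[:t]); sorted because nums is non-negative
    prefix = [0]
    for x in nums:
        prefix.append(prefix[-1] + x)

    ans = 0
    # prefix[i] <= prefix[j] - prefix[i] <= prefix[-1] - prefix[j]
    # form 1: prefix[j] >= prefix[i] * 2
    # form 2: prefix[j] <= (prefix[-1] + prefix[i]) // 2
    for i in range(1, len(nums)):
        j = bisect_left(prefix, 2*prefix[i])                   # smallest j passing form 1
        k = bisect_right(prefix, (prefix[i] + prefix[-1])//2)  # one past the largest j passing form 2
        # clamp to i < j < len(nums) so A2 and A3 stay non-empty
        ans += max(0, min(len(nums), k) - max(i+1, j))
    return ans % 1000000007

# bisect needs a sorted prefix, so the input must be non-negative
n = [1, 2, 2, 2, 5, 0]
print(waysToSplit(n))  # 3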

When a two-sided inequality cannot be established

Q1: Count the number of ways to split a list nums into three non-empty subarrays A1, A2, and A3, with sums S1, S2, and S3 respectively, such that S2 ≤ S1 + S3. Return the value modulo (10⁹ + 7).

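In this case, given the prefix array and cut points i < j, we can only form one inequality, prefix[j] – prefix[i] ≤ prefix[i] + (prefix[-1] – prefix[j]), which simplifies to prefix[j] ≤ (prefix[i] * 2 + prefix[-1]) // 2. This gives an upper bound on j but no matching lower bound, so instead of a pair of bisect calls we write our own binary search for the largest valid j. As before, this assumes nums is non-negative, so that the condition is monotone in j. Note that the code uses an inclusive prefix array, where pre[t] is the sum of nums[0..t].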

def countWays(nums):
    n = len(nums)
    # pre[t] = sum(nums[:t + 1]) (inclusive prefix sums)
    pre = [0] * n
    for i in range(n):
        pre[i] = (pre[i - 1] if i else 0) + nums[i]

    def check(l, r):
        # first cut after index l, second cut after index r
        s1 = pre[l]
        s2 = pre[r] - pre[l]
        s3 = pre[-1] - pre[r]
        return s2 <= s1 + s3

    ans, mod = 0, 10**9 + 7
    for i in range(n - 2):
        # r ranges over [i + 1, n - 2]; check is True up to some r, then False
        lo, hi = i + 1, n - 2
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if check(i, mid):
                lo = mid
            else:
                hi = mid
        # every r from i + 1 up to the largest passing r is valid
        if check(i, hi):
            ans += hi - i
        elif check(i, lo):
            ans += lo - i
        ans %= mod
    return ans

print(countWays([1,2,2,3]))  # 3
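
The hand-written binary search is only valid if, for a fixed first cut, the condition S2 ≤ S1 + S3 is True for small second cuts and then False from some point on. The sketch below prints that row of truth values so the pattern can be inspected directly. The helper name predicate_row and both sample inputs are my own, chosen for illustration.

```python
from itertools import accumulate
from typing import List


def predicate_row(nums: List[int], i: int) -> List[bool]:
    """Evaluate S2 <= S1 + S3 for every second cut r, with the first cut fixed after index i."""
    pre = list(accumulate(nums))  # pre[t] = sum(nums[:t + 1]), inclusive prefix sums
    total = pre[-1]
    row = []
    for r in range(i + 1, len(nums) - 1):  # r <= n - 2 keeps A3 non-empty
        s1, s2, s3 = pre[i], pre[r] - pre[i], total - pre[r]
        row.append(s2 <= s1 + s3)
    return row


print(predicate_row([1, 2, 2, 3, 9, 1], 0))  # [True, True, True, False]
print(predicate_row([1, 5, -5, 1, 1], 0))    # [False, True, True]
```

With non-negative values the row is a run of True followed by a run of False, which is what the search assumes. The second input contains a negative number and breaks that pattern, so binary search would not be safe there.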

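Q2: Consider another question, LeetCode 410: Split Array Largest Sum, which asks us to split nums into k non-empty contiguous subarrays so that the largest subarray sum is as small as possible. Here we can't establish any relationship among the subarrays, but we can still binary-search on the answer itself: left is max(nums), right is sum(nums), and for each candidate value we greedily count how many subarrays are needed to keep every sum within it.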

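from typing import List

def splitArray(nums: List[int], k: int) -> int:
    def min_sub_required(nums: List[int], max_sum_allowed: int) -> int:
        # greedily pack elements; open a new subarray when the cap would be exceeded
        curr_sum = 0
        splits = 0

        for i in nums:
            if curr_sum + i <= max_sum_allowed:
                curr_sum += i
            else:
                curr_sum = i
                splits += 1
        return splits + 1

    left = max(nums)   # no cap can be below the largest element
    right = sum(nums)  # a single subarray always fits under the total
    while left <= right:
        max_sum_allowed = (left + right) // 2
        if min_sub_required(nums, max_sum_allowed) <= k:
            # feasible: remember it and try a smaller cap
            right = max_sum_allowed - 1
            min_largest_split_sum = max_sum_allowed
        else:
            left = max_sum_allowed + 1
    return min_largest_split_sum

print(splitArray([7, 2, 5, 10, 8], 2))  # 18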
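
Binary search on the answer works here because the greedy subarray count never increases as the allowed maximum grows. As a rough cross-check, the sketch below prints that count for a few caps and then finds the answer by plain linear scan. The names pieces_needed and split_array_linear are my own, and the scan is only meant as a slow reference for small inputs.

```python
from typing import List


def pieces_needed(nums: List[int], cap: int) -> int:
    """Greedy count of subarrays when no subarray sum may exceed cap (cap >= max(nums))."""
    pieces, running = 1, 0
    for x in nums:
        if running + x <= cap:
            running += x
        else:  # x does not fit: start a new subarray with it
            pieces += 1
            running = x
    return pieces


def split_array_linear(nums: List[int], k: int) -> int:
    """Reference answer: try every cap from max(nums) upward; the first feasible one wins."""
    for cap in range(max(nums), sum(nums) + 1):
        if pieces_needed(nums, cap) <= k:
            return cap
    raise ValueError("expects non-empty, non-negative nums and k >= 1")


nums = [7, 2, 5, 10, 8]
print([pieces_needed(nums, cap) for cap in (10, 14, 17, 18, 32)])  # [4, 3, 3, 2, 1]
print(split_array_linear(nums, 2))  # 18
```

The counts fall from 4 to 1 as the cap rises, and 18 is the first cap that needs at most 2 subarrays, namely [7, 2, 5] and [10, 8].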
