Leetcode #2448 Minimum Cost to Make Array Equal

Read the question here: https://leetcode.com/problems/minimum-cost-to-make-array-equal/description/

Let the final array be [k, k, k, ...], with every element equal to k.

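An optimal k always lies in the range [min(nums), max(nums)]. Moving k below min(nums) or above max(nums) moves it away from every element at once, which can only increase the cost.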

So if we try every integer k in this range, compute the total cost of making the entire array equal to k, and keep the minimum, we are done. The remaining question is how to compute that cost quickly for each k.
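A direct version of this idea, as a minimal sketch: `total_cost` and `min_cost_brute_force` are hypothetical helper names, not part of the original solution. Each candidate costs O(n) to evaluate, so this is O(n·a) overall and too slow for wide ranges, but it is a handy reference to test a faster version against.

```python
from typing import List


def total_cost(nums: List[int], cost: List[int], k: int) -> int:
    """Cost of making every element equal to k: the sum of |k - x| * c."""
    return sum(abs(k - x) * c for x, c in zip(nums, cost))


def min_cost_brute_force(nums: List[int], cost: List[int]) -> int:
    """Try every k in [min(nums), max(nums)], paying O(n) per candidate."""
    return min(total_cost(nums, cost, k) for k in range(min(nums), max(nums) + 1))


print(min_cost_brute_force([1, 3, 5, 2], [2, 3, 1, 14]))  # → 8
```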

How to calculate the total cost for a given k

First, sort nums and reorder cost so that each cost stays paired with its number.

nums = [a, b, c, d, e, f] | cost = [p, q, r, s, t, u]
=> a ≤ b ≤ c ≤ d ≤ e ≤ f
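One way to do this pairing in Python, as a small sketch (`sort_with_costs` is a hypothetical helper name): zip the two lists into tuples, sort the tuples, which orders them by number first, and unzip them again.

```python
from typing import List, Tuple


def sort_with_costs(nums: List[int], cost: List[int]) -> Tuple[List[int], List[int]]:
    """Sort nums ascending and carry each element's cost along with it."""
    pairs = sorted(zip(nums, cost))  # tuples sort by num first (ties broken by cost)
    return [x for x, _ in pairs], [c for _, c in pairs]


sorted_nums, sorted_cost = sort_with_costs([1, 3, 5, 2], [2, 3, 1, 14])
print(sorted_nums, sorted_cost)  # → [1, 2, 3, 5] [2, 14, 3, 1]
```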

k can fall anywhere between a and f. Say it falls between b and c:
a ≤ b < k < c ≤ d ≤ e ≤ f

The cost of moving one element, say a, to k is |k-a|p, which is (k-a)p because a < k. For an element above k, say c, it is (c-k)r.

So the total cost for the entire array is:

= (k-a)p + (k-b)q + (c-k)r + (d-k)s + (e-k)t + (f-k)u
= [(k-a)p + (k-b)q]  +  [(c-k)r + (d-k)s + (e-k)t + (f-k)u]
= [kp-ap + kq-bq] + [cr-kr + ds-ks + et-kt + fu-ku]
= [kp+kq -ap-bq] + [cr+ds+et+fu -kr-ks-kt-ku]
= [ k(p+q) - (ap+bq) ] + [ (cr+ds+et+fu) - k(r+s+t+u) ]
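A quick numeric sanity check of this rearrangement, using arbitrary sample values (not from the problem) for a..f and p..u with k between b and c: the direct sum and the two-bracket form should agree.

```python
# Arbitrary sample values: a..f sorted, k chosen between b and c.
a, b, c, d, e, f = 1, 2, 4, 5, 7, 9
p, q, r, s, t, u = 3, 1, 4, 1, 5, 9
k = 3

# Direct form: one term per element.
direct = (k - a) * p + (k - b) * q + (c - k) * r + (d - k) * s + (e - k) * t + (f - k) * u

# Split form: [ k(p+q) - (ap+bq) ] + [ (cr+ds+et+fu) - k(r+s+t+u) ]
left = k * (p + q) - (a * p + b * q)
right = (c * r + d * s + e * t + f * u) - k * (r + s + t + u)

print(direct, left + right)  # → 87 87
```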

Notice that the left bracket contains only the costs p, q of the elements a, b that are less than k, and the right bracket contains only the costs r, s, t, u of the elements c, d, e, f that are greater than k.

The split point (here, between b and c) can be found by binary-searching for k in the sorted nums.
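In Python this binary search is provided by the standard `bisect` module: `bisect_left(nums, k)` returns the first index whose value is ≥ k, and `bisect_right(nums, k)` the first index whose value is > k. A small sketch (`split` is a hypothetical helper name) showing how the two differ when k has duplicates; elements equal to k cost nothing, so they belong to neither bracket.

```python
from bisect import bisect_left, bisect_right
from typing import List, Tuple


def split(nums: List[int], k: int) -> Tuple[int, int]:
    """Locate k's split point in sorted nums."""
    lo = bisect_left(nums, k)   # nums[:lo] are all < k
    hi = bisect_right(nums, k)  # nums[hi:] are all > k, and nums[lo:hi] all equal k
    return lo, hi


nums = [1, 2, 2, 2, 5, 7]  # sorted, with duplicates of 2
print(split(nums, 2))  # → (1, 4)
print(split(nums, 3))  # → (4, 4): 3 is absent, so both searches agree
```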

The sums in parentheses run over contiguous stretches of the sorted arrays; each full sum is simply split between the left and right brackets:
p+q+r+s+t+u is split as (p+q) on the left and (r+s+t+u) on the right
ap+bq+cr+ds+et+fu is split as (ap+bq) on the left and (cr+ds+et+fu) on the right

So sums like (p+q) and (r+s+t+u) can be read off a prefix-sum array of cost in O(1).

Likewise, sums like (ap+bq) and (cr+ds+et+fu) can be read off a prefix-sum array of nums[i]*cost[i].
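A minimal sketch of both prefix-sum arrays using `itertools.accumulate`, with sample values for a..f and p..u (`prefix_sums`, `cs` and `ws` are hypothetical names). It uses a leading 0 so that the sum over indices [i, j) is `arr[j] - arr[i]`; that is a design alternative to the trailing-0 sentinel the final code appends, and it avoids relying on negative indexing.

```python
from itertools import accumulate
from typing import List, Tuple


def prefix_sums(nums: List[int], cost: List[int]) -> Tuple[List[int], List[int]]:
    """Prefix sums with a leading 0, so the sum over indices [i, j) is arr[j] - arr[i]."""
    cs = [0, *accumulate(cost)]                               # cs[i] = cost[0] + ... + cost[i-1]
    ws = [0, *accumulate(x * c for x, c in zip(nums, cost))]  # ws[i] = nums[0]*cost[0] + ... + nums[i-1]*cost[i-1]
    return cs, ws


nums = [1, 2, 4, 5, 7, 9]  # sample values for a..f (sorted)
cost = [3, 1, 4, 1, 5, 9]  # sample values for p..u
cs, ws = prefix_sums(nums, cost)
i, n = 2, len(nums)        # split for k = 3: nums[:2] = [a, b] are below k

print(cs[i] - cs[0], ws[i] - ws[0])  # (p+q), (ap+bq)           → 4 5
print(cs[n] - cs[i], ws[n] - ws[i])  # (r+s+t+u), (cr+ds+et+fu) → 19 137
```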

Final Code:

from bisect import bisect_left, bisect_right
from typing import List

class Solution:
    def minCost(self, nums: List[int], cost: List[int]) -> int:
        n = len(nums)
        pairs = sorted(zip(nums, cost))  # sort by num, keeping each cost attached
        nums = [x for x, _ in pairs]
        cost = [c for _, c in pairs]
        mn, mx = nums[0], nums[-1]

        cs = [cost[0]]              # cs[i] = cost[0] + ... + cost[i]
        sums = [nums[0] * cost[0]]  # sums[i] = nums[0]*cost[0] + ... + nums[i]*cost[i]
        for i in range(1, n):
            sums.append(sums[-1] + nums[i] * cost[i])
            cs.append(cs[-1] + cost[i])
        sums.append(0)  # sentinel: sums[-1] == 0, the empty prefix
        cs.append(0)    # sentinel: cs[-1] == 0, the empty prefix

        def get_cost(num):
            # nums[0..i-1] are strictly less than num (the left bracket)
            i = bisect_left(nums, num)
            left = num * cs[i - 1] - sums[i - 1]

            # There can be duplicates: elements equal to num cost nothing,
            # so the right bracket starts after the last copy of num
            i = bisect_right(nums, num)
            right = (sums[n - 1] - sums[i - 1]) - num * (cs[n - 1] - cs[i - 1])

            return left + right

        return min(get_cost(num) for num in range(mn, mx + 1))
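A possible refinement, sketched under its own assumptions and not the solution above: the total cost is a sum of terms |k - x|·c with positive c, so it is convex and piecewise linear in k, bending only at the values in nums. Its minimum over [min(nums), max(nums)] is therefore attained at one of those values, and evaluating get_cost only there drops the loop from O(a log n) to O(n log n). `min_cost_at_breakpoints` is a hypothetical name, and it uses leading-0 prefix arrays instead of the trailing sentinel.

```python
from bisect import bisect_left, bisect_right
from itertools import accumulate
from typing import List


def min_cost_at_breakpoints(nums: List[int], cost: List[int]) -> int:
    """Variant: evaluate the total cost only at the values that occur in nums."""
    pairs = sorted(zip(nums, cost))
    xs = [x for x, _ in pairs]
    cs = [0, *accumulate(c for _, c in pairs)]      # cs[i] = sum of costs of xs[:i]
    ws = [0, *accumulate(x * c for x, c in pairs)]  # ws[i] = sum of x*c over xs[:i]
    n = len(xs)

    def get_cost(k: int) -> int:
        lo = bisect_left(xs, k)   # xs[:lo] < k
        hi = bisect_right(xs, k)  # xs[hi:] > k
        left = k * cs[lo] - ws[lo]
        right = (ws[n] - ws[hi]) - k * (cs[n] - cs[hi])
        return left + right

    # Convex and piecewise linear in k with bends only at values of nums,
    # so the minimum is attained at one of those values.
    return min(get_cost(k) for k in set(xs))


print(min_cost_at_breakpoints([1, 3, 5, 2], [2, 3, 1, 14]))       # → 8
print(min_cost_at_breakpoints([2, 2, 2, 2, 2], [4, 2, 8, 1, 3]))  # → 0
```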

Time Complexity:
n = len(nums) = len(cost); a = max(nums) - min(nums)

minCost(): O(n log n + a log n)

  • Sort: O(n log n)

  • Prefix Arrays Build: O(n)

  • Loop to calculate cost & find min: O(a log n)

    • loop over k in {min(nums) ... max(nums)}: O(a)

    • get_cost(): O(log n)

      • 2 binary searches on nums of length n; the prefix-sum lookups are O(1)