Python: Improving long cumulative sum

Posted by Bo102010 on Stack Overflow
Published on 2010-05-30T03:06:33Z


I have a program that operates on a large set of experimental data. The data is stored as a list of objects that are instances of a class with the following attributes:

  • time_point - the time of the sample
  • cluster - the name of the cluster of nodes from which the sample was taken
  • node - the name of the node from which the sample was taken
  • qty1 - the value of the sample for the first quantity
  • qty2 - the value of the sample for the second quantity

I need to derive some values from the data set, grouped in three ways: once for the data set as a whole, once for each cluster of nodes, and once for each node. The values depend on the cumulative sums of qty1 and qty2, taken in time order. For each group I need the maximum of the element-wise sum of those two cumulative sums, the time point at which that maximum occurred, and the cumulative values of qty1 and qty2 at that time point.
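To make the target values concrete, here is a hypothetical four-sample example (invented numbers, not the experimental data), worked with a plain running-total loop over (time_point, qty1, qty2) tuples. Unlike the code that follows, which starts the maximum at 0, this sketch simply takes the first sample as the initial peak.

```python
# Hypothetical time-sorted samples: (time_point, qty1, qty2)
samples = [(1, 5, -1), (2, 3, 2), (3, -4, 1), (4, 2, -6)]

cuml1 = cuml2 = 0
peak = None  # (time_point, cumulative qty1, cumulative qty2, combined value)
for time_point, q1, q2 in samples:
  cuml1 += q1
  cuml2 += q2
  combo = cuml1 + cuml2
  if peak is None or combo > peak[3]:
    peak = (time_point, cuml1, cuml2, combo)

print(peak)  # (2, 8, 1, 9)
```

The cumulative sums are 5, 8, 4, 6 for qty1 and -1, 1, 2, -4 for qty2, so the combined series is 4, 9, 6, 2 and the peak is 9 at time point 2, where the cumulative quantities are 8 and 1.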

I came up with the following solution:

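import operator
from collections import defaultdict, namedtuple

# Assumed minimal stand-in; the actual Peak class isn't shown here
Peak = namedtuple('Peak', ['time_point', 'qty1', 'qty2'])

dataset.sort(key=operator.attrgetter('time_point'))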

# For the whole set
sys_qty1 = 0
sys_qty2 = 0
sys_combo = 0
sys_max = 0
system_peak = None  # stays None if the running total never exceeds 0

# For the cluster grouping
cluster_qty1 = defaultdict(int)
cluster_qty2 = defaultdict(int)
cluster_combo = defaultdict(int)
cluster_max = defaultdict(int)
cluster_peak = defaultdict(int)

# For the node grouping
node_qty1 = defaultdict(int)
node_qty2 = defaultdict(int)
node_combo = defaultdict(int)
node_max = defaultdict(int)
node_peak = defaultdict(int)

for t in dataset:
  # For the whole system ######################################################
  sys_qty1 += t.qty1
  sys_qty2 += t.qty2
  sys_combo = sys_qty1 + sys_qty2
  if sys_combo > sys_max:
    sys_max = sys_combo
    # The Peak class is to record the time point and the cumulative quantities
    system_peak = Peak(time_point=t.time_point,
                       qty1=sys_qty1,
                       qty2=sys_qty2)
  # For the cluster grouping ##################################################
  cluster_qty1[t.cluster] += t.qty1
  cluster_qty2[t.cluster] += t.qty2
  cluster_combo[t.cluster] = cluster_qty1[t.cluster] + cluster_qty2[t.cluster]
  if cluster_combo[t.cluster] > cluster_max[t.cluster]:
    cluster_max[t.cluster] = cluster_combo[t.cluster]
    cluster_peak[t.cluster] = Peak(time_point=t.time_point,
                                   qty1=cluster_qty1[t.cluster],
                                   qty2=cluster_qty2[t.cluster])
  # For the node grouping #####################################################
  node_qty1[t.node] += t.qty1
  node_qty2[t.node] += t.qty2
  node_combo[t.node] = node_qty1[t.node] + node_qty2[t.node]
  if node_combo[t.node] > node_max[t.node]:
    node_max[t.node] = node_combo[t.node]
    node_peak[t.node] = Peak(time_point=t.time_point,
                             qty1=node_qty1[t.node],
                             qty2=node_qty2[t.node])

This produces the correct output, but I'm wondering if it can be made more readable/Pythonic, and/or faster/more scalable.

The above is attractive in that it only loops through the (large) dataset once, but unattractive in that I've essentially copied and pasted the same algorithm three times.

To avoid the copy/paste issues of the above, I tried this also:

def find_peaks(level, dataset):
  """Return {group_key: Peak} for one grouping level.

  level is 'system' (a single group for the whole dataset) or the name of
  the attribute to group by ('cluster' or 'node'). dataset must already be
  sorted by time_point.
  """

  def grouping(obj, attr_name):
    # At the system level every record belongs to the single 'system' group
    if attr_name == 'system':
      return attr_name
    return getattr(obj, attr_name)

  cuml_qty1 = defaultdict(int)
  cuml_qty2 = defaultdict(int)
  level_max = defaultdict(int)
  level_peak = {}

  for t in dataset:
    key = grouping(t, level)  # compute the group key once per record
    cuml_qty1[key] += t.qty1
    cuml_qty2[key] += t.qty2
    combo = cuml_qty1[key] + cuml_qty2[key]
    if combo > level_max[key]:
      level_max[key] = combo
      level_peak[key] = Peak(time_point=t.time_point,
                             qty1=cuml_qty1[key],
                             qty2=cuml_qty2[key])
  return level_peak

system_peak = find_peaks('system', dataset)
cluster_peak = find_peaks('cluster', dataset)
node_peak = find_peaks('node', dataset)
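One way to keep the single pass of the first snippet while removing the copy/paste is to key every running total by a (level, group) pair, so the same few lines serve all three groupings. A minimal sketch, assuming records expose time_point, cluster, node, qty1 and qty2 as described above and that the dataset is already sorted by time_point; find_all_peaks and Sample are hypothetical names, and Peak is a stand-in namedtuple for the real class:

```python
from collections import defaultdict, namedtuple

# Stand-in for the Peak class used above (its real definition isn't shown)
Peak = namedtuple('Peak', ['time_point', 'qty1', 'qty2'])


def find_all_peaks(dataset):
  """Single pass over a time-sorted dataset; returns {(level, group): Peak}."""
  cuml_qty1 = defaultdict(int)
  cuml_qty2 = defaultdict(int)
  best = defaultdict(int)  # highest combined cumulative value seen per key
  peaks = {}
  for t in dataset:
    # Each record contributes to exactly one group at each of the three levels
    for key in (('system', None), ('cluster', t.cluster), ('node', t.node)):
      cuml_qty1[key] += t.qty1
      cuml_qty2[key] += t.qty2
      combo = cuml_qty1[key] + cuml_qty2[key]
      if combo > best[key]:  # as in the first snippet, only peaks above 0 count
        best[key] = combo
        peaks[key] = Peak(t.time_point, cuml_qty1[key], cuml_qty2[key])
  return peaks


# Hypothetical records for trying it out
Sample = namedtuple('Sample', ['time_point', 'cluster', 'node', 'qty1', 'qty2'])
example = [Sample(1, 'A', 'a1', 5, -1), Sample(2, 'B', 'b1', 3, 2),
           Sample(3, 'A', 'a2', -4, 1), Sample(4, 'A', 'a1', 2, -6)]
peaks = find_all_peaks(example)
print(peaks[('cluster', 'A')])  # Peak(time_point=1, qty1=5, qty2=-1)
```

Prefixing each key with its level keeps a cluster and a node that happen to share a name from clashing. As in the first snippet, a group whose combined total never rises above 0 gets no entry (node 'a2' in the example).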

For the (non-grouped) system-level calculations, I also came up with this, which is pretty:

dataset.sort(key=operator.attrgetter('time_point'))

def cuml_sum(seq):
  rseq = []
  t = 0
  for i in seq:
    t += i
    rseq.append(t)
  return rseq

time_get = operator.attrgetter('time_point')
q1_get = operator.attrgetter('qty1')
q2_get = operator.attrgetter('qty2')

timeline = [time_get(t) for t in dataset]
cuml_qty1 = cuml_sum([q1_get(t) for t in dataset])
cuml_qty2 = cuml_sum([q2_get(t) for t in dataset])
cuml_combo = [q1 + q2 for q1, q2 in zip(cuml_qty1, cuml_qty2)]

combo_max = max(cuml_combo)
i_max = cuml_combo.index(combo_max)  # position of the (first) peak
time_max = timeline[i_max]
q1_at_max = cuml_qty1[i_max]
q2_at_max = cuml_qty2[i_max]
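A compact sketch of the same system-level calculation, assuming Python 3.2 or later for itertools.accumulate (newer than the Python releases current when this was posted; on older versions the hand-written cuml_sum does the same job). The peak is located by position: the index of the maximum in the combined list also indexes timeline, cuml_qty1 and cuml_qty2, because all four lists line up with the sorted dataset. find_system_peak and Sample are hypothetical names.

```python
from collections import namedtuple
from itertools import accumulate
from operator import attrgetter


def find_system_peak(dataset):
  """Return (time point, cumulative qty1, cumulative qty2) at the combined peak.

  dataset must already be sorted by time_point and must not be empty.
  """
  timeline = [t.time_point for t in dataset]
  cuml_qty1 = list(accumulate(map(attrgetter('qty1'), dataset)))
  cuml_qty2 = list(accumulate(map(attrgetter('qty2'), dataset)))
  cuml_combo = [q1 + q2 for q1, q2 in zip(cuml_qty1, cuml_qty2)]
  # Index of the largest combined value; max() returns the first of any ties
  i_max = max(range(len(cuml_combo)), key=cuml_combo.__getitem__)
  return timeline[i_max], cuml_qty1[i_max], cuml_qty2[i_max]


# Hypothetical records for trying it out
Sample = namedtuple('Sample', ['time_point', 'qty1', 'qty2'])
example = [Sample(1, 5, -1), Sample(2, 3, 2), Sample(3, -4, 1), Sample(4, 2, -6)]
print(find_system_peak(example))  # (2, 8, 1)
```

Using max() with a key over the indices finds the peak position in one scan, instead of one scan for max() and another for list.index().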

However, despite this version's cool use of list comprehensions and zip(), it loops through the dataset three times just for the system-level calculations, and I can't think of a good way to do the cluster-level and node-level calculations without doing something slow like:

timeline = {}
cuml_qty1 = {}
#...etc.

cluster_list = set(t.cluster for t in dataset)
for c in cluster_list:
  # Each comprehension rescans the whole dataset for every cluster
  timeline[c] = [time_get(t) for t in dataset if t.cluster == c]
  cuml_qty1[c] = cuml_sum([q1_get(t) for t in dataset if t.cluster == c])
  #...etc.
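The slow part of that pattern is that it rescans the entire dataset once per cluster, so the cost grows with records times clusters. A sketch of a common alternative: partition the already time-sorted records into per-group lists in one pass with collections.defaultdict(list), then run the list-based calculation on each bucket. Appending in dataset order keeps each bucket time-sorted. This assumes Python 3.2 or later for itertools.accumulate; bucket_by, peaks_by_group and Sample are hypothetical names for illustration.

```python
from collections import defaultdict, namedtuple
from itertools import accumulate


def bucket_by(dataset, attr):
  """Split records into {group: [records]} in one pass, preserving order."""
  buckets = defaultdict(list)
  for t in dataset:
    buckets[getattr(t, attr)].append(t)
  return buckets


def peaks_by_group(dataset, attr):
  """Return {group: (time_point, cuml qty1, cuml qty2)} at each group's peak."""
  result = {}
  for group, records in bucket_by(dataset, attr).items():
    cuml_qty1 = list(accumulate(t.qty1 for t in records))
    cuml_qty2 = list(accumulate(t.qty2 for t in records))
    combo = [q1 + q2 for q1, q2 in zip(cuml_qty1, cuml_qty2)]
    i = combo.index(max(combo))  # first position of the largest combined value
    result[group] = (records[i].time_point, cuml_qty1[i], cuml_qty2[i])
  return result


# Hypothetical records for trying it out
Sample = namedtuple('Sample', ['time_point', 'cluster', 'node', 'qty1', 'qty2'])
example = [Sample(1, 'A', 'a1', 5, -1), Sample(2, 'B', 'b1', 3, 2),
           Sample(3, 'A', 'a2', -4, 1), Sample(4, 'A', 'a1', 2, -6)]
print(peaks_by_group(example, 'cluster'))  # {'A': (1, 5, -1), 'B': (2, 3, 2)}
```

Unlike the first snippet, which only records a peak once a group's combined total exceeds 0, this reports the largest combined value for every group even when it is negative. The buckets also hold a reference to every record at once, which costs memory proportional to the dataset.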

Does anyone here at Stack Overflow have suggestions for improvements? The first snippet above runs well for my initial dataset (on the order of a million records), but later datasets will have more records and clusters/nodes, so scalability is a concern.

This is my first non-trivial use of Python, and I want to make sure I'm taking proper advantage of the language (this is replacing a very convoluted set of SQL queries, and earlier versions of the Python code were essentially very inefficient straight translations of what those queries did). I don't normally do much programming, so I may be missing something elementary.

Many thanks!
