// Memory with one write port and one read port, each declared under its own domain.
//   T    : type of a single stored element
//   SIZE : number of elements in the memory
module DualPortMem #(T, int SIZE) {

	// Backing storage for the memory contents
	state T[SIZE] mem

	// ---- Write side ----
	domain write_clk
	interface write : bool write, int addr, T data

	// Only store when the write action is triggered
	when write {
		mem[addr] = data
	}

	// ---- Read side ----
	domain read_clk
	interface read : int read_addr -> T read_data

	// Hand the complete array over to the read side through a CrossDomain,
	// then select the requested element from that copy.
	CrossDomain #(T: type T[SIZE]) mem_on_read_side
	mem_on_read_side.in = mem
	
	read_data = mem_on_read_side.out[read_addr]
}

// First-in first-out buffer with a push side and a pop side in separate domains.
//   T           : element type
//   DEPTH       : number of slots in the circular buffer; read_addr == write_addr means empty
//   READY_SLACK : see below
// Push side: `ready` tells the producer it may keep pushing; `push`/`data_in` are specified
//            READY_SLACK latency units after `ready`.
// Pop side : when `pop` is triggered, `data_valid` reports whether an element was available
//            and `data_out` carries it (through one pipelining register).
module FIFO #(
	T,
	int DEPTH,
	// The FIFO may still receive data for several cycles after ready is de-asserted
	int READY_SLACK
) {
	state T[DEPTH] mem
	state int read_addr
	state int write_addr

	// Both pointers start equal, i.e. the FIFO starts out empty
	initial read_addr = 0
	initial write_addr = 0

	domain push_clk
	output bool ready'0
	interface push : bool push'READY_SLACK, T data_in'READY_SLACK

	domain pop_clk
	interface pop : bool pop -> bool data_valid, T data_out
	
	// The pop side needs the write pointer to detect "empty"
	CrossDomain write_to_pop
	write_to_pop.in = write_addr

	// The push side needs the read pointer to compute the remaining space
	CrossDomain read_to_push
	read_to_push.in = read_addr

	// The pop side reads the stored elements
	CrossDomain mem_to_pop
	mem_to_pop.in = mem

	when pop {
		// Equal pointers mean there is nothing to pop
		data_valid = read_addr != write_to_pop.out
		when data_valid {
			// Add a pipelining register on the read data.
			// NOTE(review): the original comment was cut off after "it can usually be fitted to the";
			// presumably the register can be absorbed into the memory's output stage - confirm.
			reg data_out = mem_to_pop.out[read_addr]
			read_addr = (read_addr + 1) % DEPTH
		}
	}

	when push {
		mem[write_addr] = data_in
		write_addr = (write_addr + 1) % DEPTH
	}
	
	// Number of occupied slots, as a wrapping subtract. Adding DEPTH first keeps the dividend
	// non-negative, so the result does not depend on how % treats negative values.
	int num_stored = (write_addr - read_to_push.out + DEPTH) % DEPTH
	// Computing the free space as (read - write) % DEPTH directly would yield 0 for an empty
	// FIFO (read == write): ready would never be asserted and nothing could ever be pushed.
	// DEPTH - num_stored gives the same value for every non-empty state, and DEPTH when empty.
	int space_remaining = DEPTH - num_stored
	gen int ALMOST_FULL_THRESHOLD = READY_SLACK + 1 // +1 for the latency reg we introduce here
	reg bool r = space_remaining > ALMOST_FULL_THRESHOLD
	// Shift the latency of r back so that it lines up with ready'0
	ready = LatencyOffset #(OFFSET: -ALMOST_FULL_THRESHOLD)(r)
}

// Passes two values through unchanged while fixing their relative latency:
// identity1 lives at latency 0, identity2 at latency OFFSET.
//   T1, T2 : types of the two pass-through values
//   OFFSET : latency of the second interface relative to the first
module JoinDomains #(T1, T2, int OFFSET) {
	interface identity1 : T1 i1'0 -> T1 o1'0
	interface identity2 : T2 i2'OFFSET -> T2 o2'OFFSET

	// Drive the outputs: both interfaces are identities, so each output is simply its input.
	// Without these assignments o1 and o2 would be left undriven.
	o1 = i1
	o2 = i2
}

// Counts `value` from 0 up to a limit chosen at run time.
//   start : triggering it stores `up_to` as the new limit and resets `value` to 0
//   iter  : `valid` is true while `value` has not yet reached the limit;
//           `value` advances by one on every cycle in which `valid` is true
module Iterator {
	// action
	interface start : bool start, int up_to
	// trigger
	interface iter : -> bool valid, state int value

	state int current_limit

	// Power up idle: with value == current_limit, valid stays false until start is triggered.
	// Without initial values, valid would be undefined before the first start.
	initial value = 0
	initial current_limit = 0

	valid = value != current_limit

	when start {
		// start takes priority over counting
		current_limit = up_to
		value = 0
	} else when valid {
		value = value + 1
	}
}

// Counter that runs `value` from 0 up to the compile-time limit UP_TO.
//   iter  : `valid` holds while `value` differs from UP_TO; `value` is the current count
//   last  : true when `value` is UP_TO - 1
//   start : resets `value` to 0, taking priority over counting
module FixedSizeIterator #(int UP_TO) {
	interface iter : -> bool valid, state int value

	output bool last

	initial value = 0

	interface start : bool start

	// Final value for which valid is still true
	gen int LAST_VALUE = UP_TO - 1

	last = value == LAST_VALUE
	valid = value != UP_TO

	when start {
		value = 0
	} else when valid {
		// Keep counting until UP_TO is reached, then hold
		value = value + 1
	}
}

// Free-running counter: cur_value starts at 0, counts up to PERIOD-1, then wraps to 0.
module SlowClockGenerator #(int PERIOD) {
	interface SlowClockGenerator : -> state int cur_value
	
	initial cur_value = 0

	// Last value before wrapping around
	gen int WRAP_AT = PERIOD - 1

	when cur_value != WRAP_AT {
		cur_value = cur_value + 1
	} else {
		cur_value = 0
	}
}

// Splits an array in two at SPLIT_POINT.
//   left  : elements i[0 .. SPLIT_POINT-1]
//   right : elements i[SPLIT_POINT .. SIZE-1]
module SplitAt #(T, int SIZE, int SPLIT_POINT) {
	interface SplitAt : T[SIZE] i -> T[SPLIT_POINT] left, T[SIZE - SPLIT_POINT] right

	// Number of elements that end up in `right`
	gen int RIGHT_SIZE = SIZE - SPLIT_POINT

	for int L_IDX in 0..SPLIT_POINT {
		left[L_IDX] = i[L_IDX]
	}
	for int R_IDX in 0..RIGHT_SIZE {
		// `right` starts where `left` ended
		right[R_IDX] = i[SPLIT_POINT + R_IDX]
	}
}

// Absolute value: o is a when a is non-negative, and -a otherwise.
module Abs {
	interface Abs : int a -> int o

	when a >= 0 {
		o = a
	} else {
		// Negative input: flip the sign
		o = -a
	}
}

// Temporary, to be replaced with slice syntax : result = vals[FROM +: OUT_SIZE]
// Copies OUT_SIZE consecutive elements of vals, starting at index FROM, into result.
module Slice #(T, int SIZE, int OUT_SIZE, int FROM) {
	interface Slice : T[SIZE] vals -> T[OUT_SIZE] result

	for int OUT_IDX in 0..OUT_SIZE {
		// Element OUT_IDX of the result is taken FROM positions further along in vals
		result[OUT_IDX] = vals[FROM + OUT_IDX]
	}
}

// One-hot decoder: bits[selection] is true, every other element of bits is false.
module BitSelect #(int SIZE) {
	interface BitSelect : int selection -> bool[SIZE] bits

	for int BIT in 0..SIZE {
		// Each output bit is set exactly when it is the selected position
		bits[BIT] = (selection == BIT)
	}
}

// Counts how many elements of `bits` are true.
// Wide inputs are split in two halves that are counted recursively and then added;
// inputs of at most BASE_CASE_SIZE bits are summed directly.
module PopCount #(int WIDTH) {
	// Should be chosen based on what's most efficient for the target architecture
	gen int BASE_CASE_SIZE = 5

	interface PopCount : bool[WIDTH] bits -> int popcount


	if WIDTH > BASE_CASE_SIZE {
		// Recursive case: split, count each half, add the partial counts
		gen int LOWER_WIDTH = WIDTH / 2
		gen int UPPER_WIDTH = WIDTH - LOWER_WIDTH

		SplitAt #(SIZE: WIDTH, SPLIT_POINT: LOWER_WIDTH, T: type bool) halves
		bool[LOWER_WIDTH] lower_bits, bool[UPPER_WIDTH] upper_bits = halves(bits)

		int lower_count = PopCount #(WIDTH: LOWER_WIDTH)(lower_bits)
		int upper_count = PopCount #(WIDTH: UPPER_WIDTH)(upper_bits)

		// Two pipelining registers on the sum
		reg reg popcount = lower_count + upper_count
	} else {
		// Base case: turn every bit into 0 or 1 and sum the whole array
		int[WIDTH] ones
		for int BIT in 0..WIDTH {
			when bits[BIT] {
				ones[BIT] = 1
			} else {
				ones[BIT] = 0
			}
		}
		reg popcount = +ones
	}
}


// Recursive Tree Add module: it recursively instantiates smaller copies of itself.
module TreeAdd #(int WIDTH) {
	interface TreeAdd : int[WIDTH] values'0 -> int total

	if WIDTH == 0 {
		// Nothing to add: the sum is zero.
		// The constant needs an explicit latency, otherwise
		// the latency of total could not be determined.
		int zero'0 = 0
		total = zero
	} else if WIDTH == 1 {
		// A single value is its own sum
		total = values[0]
	} else {
		gen int LOWER_WIDTH = WIDTH / 2
		gen int UPPER_WIDTH = WIDTH - LOWER_WIDTH

		// Submodules can be instantiated inline...
		int[LOWER_WIDTH] lower_vals, int[UPPER_WIDTH] upper_vals = SplitAt #(SIZE: WIDTH, SPLIT_POINT: LOWER_WIDTH, T: type int)(values)

		// ...or declared first and used afterwards.
		TreeAdd #(WIDTH: LOWER_WIDTH) add_lower
		TreeAdd #(WIDTH: UPPER_WIDTH) add_upper
		int lower_sum = add_lower(lower_vals)
		int upper_sum = add_upper(upper_vals)
		
		// One pipelining register per tree level.
		// Latency Counting works out the resulting latencies.
		reg total = lower_sum + upper_sum
	}
}