from ctypes import *

import sys, math, random, time, multiprocessing, zlib, array, struct, os

# Object.type codes. ObjectData's union members are declared in the same
# order, so code N selects member N - 1 (see Object.get_data).
TYPE_INVALID, SDF_PLANE, SDF_LINE, SDF_SPHERE, SDF_LOGIC, MESH = 0, 1, 2, 3, 4, 5
# SDFLogicOpData.op codes.
OP_INTERSECT, OP_UNION, OP_DIFFERENCE = 0, 1, 2

def smoothstep(edge0, edge1, x):
    """Return the cubic Hermite smoothstep of *x* between *edge0* and *edge1*.

    *x* is mapped linearly so that edge0 -> 0 and edge1 -> 1 (the edges may be
    given in either order), clamped to [0, 1], then eased with 3t^2 - 2t^3.
    Equal edges raise ZeroDivisionError.
    """
    t = (x - edge0) / (edge1 - edge0)
    t = min(1, t)
    t = max(0, t)
    return t * t * (3 - 2 * t)

class VecMixin(object):
    """Operations shared by the ctypes-backed vector types.

    Subclasses supply ``v`` (indexable components), ``dot`` and ``__mul__``
    (vector * scalar); everything here is expressed in terms of those.
    """

    def length(self):
        """Euclidean length."""
        squared = self.length2()
        return math.sqrt(squared)

    def length2(self):
        """Squared length (dot product with itself)."""
        return self.dot(self)

    def normalize(self):
        """Return a unit-length copy; a zero vector is returned unscaled."""
        magnitude = self.length()
        # ``or 1`` keeps zero vectors from raising ZeroDivisionError.
        scale = 1 / (magnitude or 1)
        return self * scale

    def __rmul__(self, other):
        # scalar * vector: delegate to the subclass' vector * scalar.
        return self * other

    def __truediv__(self, other):
        assert isinstance(other, (float, int)), other
        quotients = [component / other for component in self.v]
        return self.__class__(*quotients)

class vec2(Structure, VecMixin):
    """Two-component vector stored as a C ``float[2]`` (float32 precision)."""

    _fields_ = [
        ('v', c_float * 2),
    ]

    def __init__(self, x, y):
        self.v[0], self.v[1] = x, y

    def dot(self, other):
        ax, ay = self.v
        bx, by = other.v
        return ax * bx + ay * by

    def __add__(self, other):
        ax, ay = self.v
        bx, by = other.v
        return vec2(ax + bx, ay + by)

    def __sub__(self, other):
        ax, ay = self.v
        bx, by = other.v
        return vec2(ax - bx, ay - by)

    def __mul__(self, other):
        """Component-wise product with a vec2, or scaling by a number."""
        if isinstance(other, vec2):
            return vec2(self.v[0] * other.v[0], self.v[1] * other.v[1])
        if isinstance(other, (float, int)):
            return vec2(self.v[0] * other, self.v[1] * other)
        raise ValueError(other)

    def rotate(self, angle):
        """Return this vector rotated counter-clockwise by *angle* radians."""
        sin_a = math.sin(angle)
        cos_a = math.cos(angle)
        x, y = self.v
        return vec2(x * cos_a - y * sin_a, x * sin_a + y * cos_a)

class vec3(Structure, VecMixin):
    """Three-component vector stored as a C ``float[3]`` (float32 precision)."""

    _fields_ = [
        ('v', c_float * 3),
    ]

    def __init__(self, x, y, z):
        self.v[0], self.v[1], self.v[2] = x, y, z

    def dot(self, other):
        ax, ay, az = self.v
        bx, by, bz = other.v
        return ax * bx + ay * by + az * bz

    def __add__(self, other):
        return vec3(*(a + b for a, b in zip(self.v, other.v)))

    def __sub__(self, other):
        return vec3(*(a - b for a, b in zip(self.v, other.v)))

    def __mul__(self, other):
        """Component-wise product with a vec3, or scaling by a number."""
        if isinstance(other, vec3):
            return vec3(*(a * b for a, b in zip(self.v, other.v)))
        if isinstance(other, (float, int)):
            return vec3(*(a * other for a in self.v))
        raise ValueError(other)

    def __rtruediv__(self, other):
        """``float / vec3`` component-wise; zero components are treated as 1."""
        assert isinstance(other, float), other
        return vec3(*(other / (c or 1) for c in self.v))

    def cross(self, other):
        assert isinstance(other, vec3)
        ax, ay, az = self.v
        bx, by, bz = other.v
        return vec3(ay * bz - az * by,
                    az * bx - ax * bz,
                    ax * by - ay * bx)

    def reflect(I, N):
        """Reflect incident direction ``I`` about normal ``N``."""
        assert isinstance(N, vec3)
        twice_projection = 2 * N.dot(I)
        return I - twice_projection * N

class mat4(Structure):
    """4x4 float matrix stored as a flat ``float[16]`` with element (row, col) at ``col + row * 4``."""

    _fields_ = [
        ('m', c_float * 16),
    ]

    def __mul__(self, other):
        """Transform point ``other`` (w = 1) and return it after the perspective divide."""
        assert isinstance(other, vec3), other
        homogeneous = tuple(other.v) + (1,)
        m = self.m
        x, y, z, w = (
            sum(homogeneous[row] * m[col + row * 4] for row in range(4))
            for col in range(4)
        )
        return vec3(x / w, y / w, z / w)


class RasterizeSettings(Structure):
    """Global render settings; the first record of the binary scene data (see Database).

    Field order and types define the on-disk layout and must not change.
    """
    _fields_ = [
        # Output image size in pixels (default Framebuffer size).
        ('width', c_uint16),
        ('height', c_uint16),
        # NOTE(review): not read in this file; rasterize_parallel uses
        # multiprocessing.Pool()'s default worker count instead.
        ('num_threads', c_int),
        # Jittered camera rays per pixel (Framebuffer.rasterize).
        ('num_samples', c_int),
        # Fuzzed shadow rays per light (ColorShader.diffuse_component).
        ('num_shadow_samples', c_int),
        # Reflection rays at the first / second reflective bounce.
        ('primary_reflection_samples', c_int),
        ('secondary_reflection_samples', c_int),
        # Edge length in pixels of the square tiles produced by get_tasks.
        ('tile_size', c_uint),
        # Non-zero enables the effect chain in Framebuffer.postprocess.
        ('postprocess', c_uint),

        # NOTE(review): presumably the vertical field of view; unused in this file.
        ('fovy', c_float),

        # Camera matrices; only ndc_to_world (NDC -> world space) is used here.
        ('projection', mat4),
        ('view', mat4),
        ('ndc_to_world', mat4),

        # Starting light intensity in ColorShader.diffuse_component.
        ('ambient', vec3),

        # Depth of field: scale of the random jitter added to each ray origin,
        # and the multiplier for the jitter subtracted from the ray end.
        ('dof_strength', c_float),
        ('dof_distance_divider', c_float),

        # NOTE(review): not read in this file.
        ('near_plane', c_float),
        ('far_plane', c_float),

        # Vignette factor: 1 - strength * (scale * d) ** exp, where d is the
        # distance from the image center normalized to the corner distance.
        ('vignette_strength', c_float),
        ('vignette_distance_scale', c_float),
        ('vignette_distance_exp', c_float),

        # Amplitude of the per-pixel multiplicative random noise.
        ('film_noise_strength', c_float),

        # Chromatic aberration offset magnitude: factor * d ** exp.
        ('chromatic_aberration_factor', c_float),
        ('chromatic_aberration_exp', c_float),

        # Size of the grayscale overlay image stored in Database.plug and the
        # parameters placing/blending it in Framebuffer.postprocess.
        ('plug_width', c_uint32),
        ('plug_height', c_uint32),
        ('plug_border', c_float),
        ('plug_begin', c_float),
        ('plug_blend', c_float),
    ]

class AABB(Structure):
    """Axis-aligned bounding box given by two opposite corners.

    Ray.intersect uses it to cull objects and BVH groups before exact tests.
    """
    _fields_ = [
        ('min', vec3),  # corner with the smallest x, y, z
        ('max', vec3),  # corner with the largest x, y, z
    ]


class Material(Structure):
    """Surface shading parameters consumed by ColorShader."""
    _fields_ = [
        # Weight of the diffuse (lit) term; 0 skips diffuse shading.
        ('diffuse', c_float),

        # Weight of the reflective term; 0 skips reflection rays.
        ('reflective', c_float),
        # Reflection sharpness: the fuzz blend alpha is 0.5 + 0.5 * shininess,
        # so 1.0 gives unfuzzed (mirror) reflections.
        ('shininess', c_float),
        # Blend between the flat material color (1) and the averaged
        # reflection samples (0) in the reflective term.
        ('opacity', c_float),

        # Base color; multiplies the diffuse light intensity.
        ('color', vec3),
    ]

class SDFPlaneData(Structure):
    """Plane SDF: distance = normal . (pos - origin).

    NOTE(review): this is a true distance only if ``normal`` is unit length;
    presumably the exporter normalizes it — confirm.
    """
    _fields_ = [
        ('origin', vec3),  # any point on the plane
        ('normal', vec3),  # points towards the positive (outside) half-space
    ]


class SDFLineData(Structure):
    """Tapered capsule around the segment p0 -> p1 (see Object.sdf).

    NOTE(review): vec3/float fields are interleaved, presumably to match a
    16-byte-aligned layout on the writer side — keep the order.
    """
    _fields_ = [
        ('p0', vec3),
        # Radius at p0.
        ('thickness', c_float),
        ('p1', vec3),
        # Fraction by which the radius shrinks towards p1
        # (radius at p1 = thickness * (1 - falloff)).
        ('falloff', c_float),
    ]

class SDFSphereData(Structure):
    """Sphere SDF: distance = |pos - center| - radius."""
    _fields_ = [
        ('center', vec3),
        ('radius', c_float),
    ]

class SDFLogicOpData(Structure):
    """Boolean combination of two other SDF objects (see Object.sdf)."""
    _fields_ = [
        # Indices of the operands in Database.objects.
        ('a_id', c_int32),
        ('b_id', c_int32),
        # One of OP_INTERSECT, OP_UNION, OP_DIFFERENCE (a minus b).
        ('op', c_uint32),
    ]

class MeshData(Structure):
    """Triangle mesh object: an inclusive index range into Database.triangles."""
    _fields_ = [
        ('first_triangle', c_int32),
        ('last_triangle', c_int32),  # inclusive
        # NOTE(review): not read in this file.
        ('transform_id', c_int32),
    ]

class ObjectData(Union):
    """Type-specific payload of an Object.

    Member order must match the Object.type codes: Object.get_data selects
    ``_fields_[type - 1]``, so SDF_PLANE (1) maps to the first member.
    """
    _fields_ = [
        ('sdf_plane', SDFPlaneData),
        ('sdf_line', SDFLineData),
        ('sdf_sphere', SDFSphereData),
        ('sdf_logic_op', SDFLogicOpData),
        ('mesh', MeshData),
    ]

class Object(Structure):
    """One scene primitive: an SDF shape, an SDF boolean op, or a triangle mesh.

    The field layout mirrors the binary scene data read by ByteData, so fields
    must not be reordered.
    """

    _fields_ = [
        ('aabb', AABB),
        ('material', Material),

        # One of the SDF_* / MESH type codes.
        ('type', c_uint8),
        # Non-zero: skipped by Database.query.
        ('ignore', c_uint8),
        ('padding0', c_uint8),
        ('padding1', c_uint8),

        ('data', ObjectData),
    ]

    def get_data(self):
        """Return the ObjectData union member selected by ``self.type``.

        Union members are declared in type-code order (SDF_PLANE == 1 selects
        the first member).

        Raises:
            ValueError: if ``self.type`` does not select a member. Without this
                check TYPE_INVALID (0) silently indexed ``_fields_[-1]``.
        """
        index = self.type - 1
        if not 0 <= index < len(ObjectData._fields_):
            raise ValueError(self.type)
        return getattr(self.data, ObjectData._fields_[index][0])

    def sdf(self, db: 'Database', pos: vec3) -> float:
        """Signed distance from ``pos`` to this object's surface.

        Args:
            db: scene database; SDF_LOGIC objects look up their operands in
                ``db.objects``.
            pos: world-space query point.

        Returns:
            The signed distance (negative inside spheres/capsules, behind the
            normal for planes). Meshes, unknown types and unknown boolean ops
            return ``sys.float_info.max``.
        """
        if self.type == SDF_PLANE:
            sdf_plane = self.get_data()
            origin_to_pos = pos - sdf_plane.origin
            return sdf_plane.normal.dot(origin_to_pos)
        elif self.type == SDF_LINE:
            sdf_line = self.get_data()
            segment = sdf_line.p1 - sdf_line.p0
            segment_length2 = segment.length2()
            if segment_length2 > 0:
                # Parameter of the closest point on the line through p0 and p1.
                h = (pos - sdf_line.p0).dot(segment) / segment_length2
            else:
                # Degenerate segment (p0 == p1): behave like a sphere around p0
                # instead of raising ZeroDivisionError.
                h = 0
            h_clamped = min(1, max(0, h))

            q = sdf_line.p0 + segment * h_clamped

            # The radius tapers linearly from ``thickness`` at p0 by ``falloff``.
            thick = sdf_line.thickness * (1 - h_clamped * sdf_line.falloff)

            return (pos - q).length() - thick
        elif self.type == SDF_SPHERE:
            sdf_sphere = self.get_data()
            return (pos - sdf_sphere.center).length() - sdf_sphere.radius
        elif self.type == SDF_LOGIC:
            sdf_logic_op = self.get_data()
            d1 = db.objects[sdf_logic_op.a_id].sdf(db, pos)
            d2 = db.objects[sdf_logic_op.b_id].sdf(db, pos)

            if sdf_logic_op.op == OP_INTERSECT:
                return max(d1, d2)
            elif sdf_logic_op.op == OP_UNION:
                return min(d1, d2)
            elif sdf_logic_op.op == OP_DIFFERENCE:
                return max(d1, -d2)

        return sys.float_info.max

    def intersect(self, db: 'Database', ray: 'Ray') -> 'IntersectionInfo':
        """Intersect ``ray`` with this object.

        SDF objects are ray-marched along the ray, stepping by the SDF value, up
        to the ray's length. Meshes test each of their triangles and keep the
        closest hit, which is given this object's material.

        Returns:
            An IntersectionInfo; a default (miss) instance when nothing is hit.

        Raises:
            ValueError: for an unknown object type.
        """
        if not ray.intersect(self.aabb):
            return IntersectionInfo()

        if self.type in (SDF_PLANE, SDF_LINE, SDF_SPHERE, SDF_LOGIC):
            # Rays starting inside (or within 0.01 of) the surface are misses.
            enters = self.sdf(db, ray.origin) > 0.01

            if not enters:
                return IntersectionInfo()

            depth = 0
            max_depth = ray.direction.length()
            MAX_MARCHING_STEPS = 128
            for _ in range(MAX_MARCHING_STEPS):
                p = ray.origin + depth * ray.direction_normalized
                dist = abs(self.sdf(db, p))
                if dist < 0.0001:
                    # Surface normal from the central-difference gradient of the SDF.
                    EPSILON = 0.001
                    normal = vec3(
                        self.sdf(db, p + vec3(EPSILON, 0, 0)) - self.sdf(db, p - vec3(EPSILON, 0, 0)),
                        self.sdf(db, p + vec3(0, EPSILON, 0)) - self.sdf(db, p - vec3(0, EPSILON, 0)),
                        self.sdf(db, p + vec3(0, 0, EPSILON)) - self.sdf(db, p - vec3(0, 0, EPSILON))
                    ).normalize()

                    return IntersectionInfo.create(enters, p, normal, self.material)

                depth += dist
                if depth >= max_depth:
                    return IntersectionInfo()

            # Step budget exhausted without converging: treat as a miss.
            return IntersectionInfo()
        elif self.type == MESH:
            best = IntersectionInfo()

            mesh = self.get_data()

            # last_triangle is inclusive.
            for j in range(mesh.first_triangle, mesh.last_triangle+1):
                triangle = db.triangles[j]

                candidate = triangle.intersect(ray)
                if ray.is_closer(candidate, best):
                    best = candidate

            if best.intersects:
                best.material = self.material

            return best
        else:
            raise ValueError(self.type)

class PointLight(Structure):
    """Point light used by ColorShader.diffuse_component."""
    _fields_ = [
        ('position', vec3),
        # Points farther than this receive no light from it.
        ('radius', c_float),
        ('color', vec3),
        # Fraction of the radius (measured inward from its edge) over which the
        # light fades smoothly to zero.
        ('falloff', c_float),
    ]

class BVHGroup(Structure):
    """Bounding box over an inclusive index range.

    In Database.bvh the range indexes Database.objects; in Database.meta it
    indexes Database.bvh (see Database.query).
    """
    _fields_ = [
        # NOTE(review): int16 limits indices to 32767 — confirm scenes stay below that.
        ('first', c_int16),
        ('last', c_int16),  # inclusive
        ('aabb', AABB),
    ]

class TriangleData(Structure):
    """A triangle with edge vectors and unit normal cached by ``inflate``.

    The field layout mirrors the binary scene data; do not reorder fields.
    """

    _fields_ = [
        ('a', vec3),
        ('b', vec3),
        ('c', vec3),

        ('a_to_b', vec3),
        ('a_to_c', vec3),
        ('normal', vec3),
    ]

    def inflate(self):
        """Derive the edge vectors and the unit normal from vertices a, b, c."""
        edge_ab = self.b - self.a
        edge_ac = self.c - self.a
        self.a_to_b = edge_ab
        self.a_to_c = edge_ac
        self.normal = edge_ac.cross(edge_ab).normalize()

    def intersect(self, ray: 'Ray') -> 'IntersectionInfo':
        """Intersect the segment from ``ray.origin`` to the ray's end with the triangle.

        Triangles whose normal does not oppose the ray direction are culled.
        The hit carries a default Material; Object.intersect substitutes the
        owning mesh's material.
        """
        normal = self.normal
        front_facing = ray.direction_normalized.dot(normal) < 0
        if not front_facing:
            return IntersectionInfo()

        seg = ray.direction
        numerator = -normal.dot(ray.origin - self.a)
        denominator = normal.dot(seg)

        # Segment (nearly) parallel to the triangle's plane.
        if abs(denominator) < 0.0001:
            return IntersectionInfo()

        # Fraction of the segment at which it crosses the plane.
        r = numerator / denominator
        if r < 0 or r > 1:
            return IntersectionInfo()

        hit = ray.origin + r * seg

        # Parametric coordinates of the hit: s along a->b, t along a->c.
        edge_u = self.a_to_b
        edge_v = self.a_to_c
        uu = edge_u.dot(edge_u)
        uv = edge_u.dot(edge_v)
        vv = edge_v.dot(edge_v)

        offset = hit - self.a
        wu = offset.dot(edge_u)
        wv = offset.dot(edge_v)

        det = uv * uv - uu * vv

        s = (uv * wv - vv * wu) / det
        if s < 0 or s > 1:
            return IntersectionInfo()

        t = (uv * wu - uu * wv) / det
        if t < 0 or t > 1:
            return IntersectionInfo()

        if s + t > 1:
            return IntersectionInfo()

        return IntersectionInfo.create(front_facing, hit, self.normal, Material())


class ByteData(object):
    """Sequential reader of ctypes records from a bytes-like buffer.

    ``self.d`` always holds the bytes that have not been consumed yet.
    """

    def __init__(self, d):
        self.d = d

    def read(self, struct_type):
        """Read one ``struct_type`` (any ctypes type) from the front of the buffer.

        Raises:
            ValueError: if fewer than ``sizeof(struct_type)`` bytes remain. The
                previous memmove-based copy read past the end of the buffer in
                that case instead of failing.
        """
        size = sizeof(struct_type)
        # from_buffer_copy checks the buffer is large enough before copying
        # and does not call a custom __init__ (e.g. vec3's).
        result = struct_type.from_buffer_copy(self.d)
        self.d = self.d[size:]
        return result

    def read_array(self, struct_type):
        """Read a C ``int`` element count, then that many ``struct_type`` records."""
        count = self.read(c_int).value
        return self.read(struct_type * count)


class Database(object):
    """In-memory scene parsed from the binary scene data.

    Records are read in this fixed order: settings, objects, point lights, BVH
    groups, meta (top-level) BVH groups, triangles, random directions and the
    plug image bytes.
    """

    def __init__(self, d):
        self.scene_bytes = d
        d = ByteData(d)
        self.settings = d.read(RasterizeSettings)
        self.objects = d.read_array(Object)
        self.point_lights = d.read_array(PointLight)
        self.bvh = d.read_array(BVHGroup)
        self.meta = d.read_array(BVHGroup)
        self.triangles = d.read_array(TriangleData)
        for triangle in self.triangles:
            triangle.inflate()
        self.random_directions = d.read_array(vec3)
        self.plug = d.read_array(c_uint8)
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(d.d) != 0:
            raise ValueError(f'scene data has {len(d.d)} unexpected trailing bytes')

    def query(self, ray: 'Ray') -> 'IntersectionInfo':
        """Return the closest entering hit along ``ray`` (a miss if none).

        Two-level culling: each meta group covers an inclusive range of
        ``self.bvh``; each BVH group covers an inclusive range of ``self.objects``.
        Objects with ``ignore`` set are skipped.
        """
        nearest = IntersectionInfo()

        for mta in self.meta:
            if not ray.intersect(mta.aabb):
                continue

            for i in range(mta.first, mta.last+1):
                grp = self.bvh[i]

                if not ray.intersect(grp.aabb):
                    continue

                for j in range(grp.first, grp.last+1):
                    obj = self.objects[j]

                    if obj.ignore:
                        continue

                    intersection = obj.intersect(self, ray)
                    if ray.is_closer(intersection, nearest):
                        nearest = intersection

        return nearest

    def fuzz_direction_vector(self, vector: vec3, alpha: float) -> vec3:
        """Blend ``vector`` (weight ``alpha``) with a random entry of
        ``random_directions`` (weight ``1 - alpha``) and renormalize.

        ``alpha == 1`` returns ``vector`` unchanged.
        """
        if alpha == 1.:
            return vector

        return (alpha * vector + (1 - alpha) * random.choice(self.random_directions)).normalize()


class ColorShader(object):
    """Computes the color of a ray hit from diffuse lighting and glossy reflections."""

    def __init__(self, db: Database):
        self.db = db

    def primary(self, ray: 'Ray', closest: 'IntersectionInfo') -> vec3:
        """Shade a camera-ray hit as the material-weighted sum of its diffuse
        and reflective terms."""
        result = vec3(0, 0, 0)

        if closest.material.diffuse:
            result += self.diffuse_component(ray, closest) * closest.material.diffuse

        if closest.material.reflective:
            result += self.reflective_primary(ray, closest) * closest.material.reflective

        return result

    def diffuse_component(self, ray: 'Ray', closest: 'IntersectionInfo') -> vec3:
        """Ambient plus soft-shadowed point-light intensity, times the material color.

        Each light within its radius is probed with ``num_shadow_samples`` fuzzed
        shadow rays; the fraction of unblocked rays scales its contribution.
        """
        settings = self.db.settings
        intensity = settings.ambient

        for light in self.db.point_lights:
            to_light_with_magnitude = light.position - closest.point
            distance = to_light_with_magnitude.length()
            if distance > light.radius:
                continue

            to_light = to_light_with_magnitude / distance

            in_shadow = 0
            for _ in range(settings.num_shadow_samples):
                n = self.db.fuzz_direction_vector(to_light, 0.9)

                # Offset along the normal so the shadow ray does not re-hit the surface.
                orig = closest.point + 0.01 * closest.normal

                if self.db.query(Ray(orig, orig + distance * n)).intersects:
                    in_shadow += 1

            # NOTE(review): with num_shadow_samples == 0 this is ``0 != 0``, so
            # lights contribute nothing — confirm that is intended.
            if in_shadow != settings.num_shadow_samples:
                falloff_start = light.radius * (1 - light.falloff)
                if falloff_start != light.radius:
                    # 1 up to falloff_start, fading smoothly to 0 at the radius.
                    factor = smoothstep(light.radius, falloff_start, distance)
                else:
                    # falloff == 0: hard cut-off at the radius. smoothstep would
                    # divide by zero with equal edges.
                    factor = 1
                shadow_factor = (1 - (in_shadow / settings.num_shadow_samples))
                intensity += light.color * factor * max(0, closest.normal.dot(to_light)) * shadow_factor

        return closest.material.color * intensity

    def reflective_primary(self, ray: 'Ray', closest: 'IntersectionInfo') -> vec3:
        """Glossy reflection at a camera-ray hit, blended with the material color by opacity.

        Shoots ``primary_reflection_samples`` fuzzed reflection rays; reflective
        surfaces they hit are shaded with one more bounce (reflective_secondary).
        """
        alpha = 0.5 + 0.5 * closest.material.shininess

        I = ray.direction_normalized
        N = closest.normal.normalize()
        R = I.reflect(N)

        num_samples = self.db.settings.primary_reflection_samples
        color = vec3(0, 0, 0)
        for _ in range(num_samples):
            n = self.db.fuzz_direction_vector(R, alpha)

            orig = closest.point + 0.001 * N

            reflected_ray = Ray(orig, orig + 100 * n)
            info = self.db.query(reflected_ray)

            sample = closest.material.color / 2
            if info.intersects:
                if info.material.diffuse:
                    sample += self.diffuse_component(reflected_ray, info) * info.material.diffuse
                if info.material.reflective:
                    sample += self.reflective_secondary(reflected_ray, info) * info.material.reflective

            color += sample

        # max(1, ...) avoids ZeroDivisionError when reflections are disabled (0 samples).
        return ((closest.material.opacity * closest.material.color) +
                (1 - closest.material.opacity) * (color / max(1, num_samples)))

    def reflective_secondary(self, ray: 'Ray', closest: 'IntersectionInfo') -> vec3:
        """Second reflective bounce; hit surfaces' reflectivity is approximated by
        their flat color instead of recursing further."""
        alpha = 0.5 + 0.5 * closest.material.shininess

        I = ray.direction_normalized
        N = closest.normal.normalize()
        R = I.reflect(N)

        num_samples = self.db.settings.secondary_reflection_samples
        color = vec3(0, 0, 0)

        for _ in range(num_samples):
            n = self.db.fuzz_direction_vector(R, alpha)

            orig = closest.point + 0.001 * N

            reflected_ray = Ray(orig, orig + 100 * n)
            info = self.db.query(reflected_ray)

            sample = closest.material.color / 2
            if info.intersects:
                sample += self.diffuse_component(reflected_ray, info) * info.material.diffuse
                sample += info.material.color * info.material.reflective

            color += sample

        # Average over the samples actually taken here. This previously divided
        # by primary_reflection_samples, mis-weighting the result whenever the
        # two sample counts differ.
        return ((closest.material.opacity * closest.material.color) +
                (1 - closest.material.opacity) * (color / max(1, num_samples)))


class PixelGrid(object):
    """A row-major w*h grid of vec3 pixels with tile (de)serialization and sampling."""

    def __init__(self, db: Database, w: int, h: int):
        self.w = w
        self.h = h
        self.pixels = (vec3 * (w * h))()
        self.db = db

    def pack(self, x0, y0, x1, y1):
        """Serialize the pixels of the rectangle [x0, x1) x [y0, y1).

        Layout: a little-endian header (crc32 of the scene bytes, x0, y0, x1, y1)
        followed by each pixel's three components, row by row, as machine-native
        floats (array typecode 'f').
        """
        pixeldata = array.array('f', (v for y in range(y0, y1)
                                        for x in range(x0, x1)
                                        for v in self.pixels[y * self.w + x].v))
        return struct.pack('<IIIII', zlib.crc32(self.db.scene_bytes), x0, y0, x1, y1) + pixeldata.tobytes()

    def unpack(self, packed):
        """Write a chunk produced by ``pack`` back into this grid.

        Raises:
            ValueError: if the chunk belongs to a different scene or its pixel
                payload does not match the rectangle in its header.
        """
        scene_hash, x0, y0, x1, y1 = struct.unpack('<IIIII', packed[:20])
        # Explicit checks instead of ``assert`` so they also run under ``python -O``.
        if scene_hash != zlib.crc32(self.db.scene_bytes):
            raise ValueError('Chunk is for a different scene')
        pixeldata = array.array('f', packed[20:])
        expected = 3 * (x1 - x0) * (y1 - y0)
        if len(pixeldata) != expected:
            raise ValueError(f'Chunk has {len(pixeldata)} floats, expected {expected}')
        values = iter(pixeldata)
        for y in range(y0, y1):
            for x in range(x0, x1):
                self.pixels[y * self.w + x].v = (next(values), next(values), next(values))

    def sample_bilinear(self, pos: vec2) -> vec3:
        """Bilinearly interpolate the grid at continuous pixel coordinates ``pos``.

        Pixel (x, y) sits at integer coordinates; positions outside the grid are
        clamped to the edge pixels.
        """
        # Clamp before splitting into integer cell + fraction. Previously negative
        # coordinates were truncated towards 0 while the fraction came from the
        # unclamped value, so e.g. x = -0.25 blended 75% of column 1 into column 0.
        px = max(0.0, min(self.w - 1, pos.v[0]))
        py = max(0.0, min(self.h - 1, pos.v[1]))

        x0 = int(px)
        y0 = int(py)

        alpha_x = px - x0
        alpha_y = py - y0

        x1 = min(self.w - 1, x0 + 1)
        y1 = min(self.h - 1, y0 + 1)

        p00 = self.pixels[y0*self.w+x0]
        p01 = self.pixels[y1*self.w+x0]
        p10 = self.pixels[y0*self.w+x1]
        p11 = self.pixels[y1*self.w+x1]

        p_0 = p00 * (1 - alpha_x) + p10 * alpha_x
        p_1 = p01 * (1 - alpha_x) + p11 * alpha_x

        return p_0 * (1 - alpha_y) + p_1 * alpha_y

    def blit(self, shade):
        """Set every pixel to ``shade(uv)``, uv running from (0, 0) to (1, 1) across the grid."""
        # A grid one pixel wide/tall would otherwise divide by zero; map it to uv 0.
        du = (self.w - 1) or 1
        dv = (self.h - 1) or 1
        for y in range(self.h):
            for x in range(self.w):
                self.pixels[y*self.w+x] = shade(vec2(x / du, y / dv))

class Framebuffer(PixelGrid):
    """The output image: renders tiles (optionally in parallel), post-processes
    and saves as binary PPM."""

    def __init__(self, db: Database, w: int=None, h: int=None):
        super().__init__(db, w or db.settings.width, h or db.settings.height)
        self.shader = ColorShader(db)

    def window_to_ndc(self, x: float, y: float, z: float) -> vec3:
        """Map window pixel coordinates (y down) to NDC x/y in [-1, 1] (y up);
        z passes through unchanged."""
        x = 2 * (x / self.w) - 1
        y = 1 - 2 * (y / self.h)
        return vec3(x, y, z)

    def postprocess(self) -> 'vec3 * (self.w * self.h)':
        """Return a post-processed copy of the pixels; ``self.pixels`` is not modified.

        When ``settings.postprocess`` is set this applies, in order: per-channel
        chromatic aberration, vignette, multiplicative film noise, a bloom pass
        (bright pixels thresholded at half resolution, blurred by 16 alternating
        3-tap passes whose direction rotates by 80 degrees each time, then added
        back) and the optional grayscale "plug" overlay.
        """
        out = (vec3 * (self.w * self.h))()
        for i in range(self.w * self.h):
            out[i] = self.pixels[i]

        if self.db.settings.postprocess:
            center = vec2(self.w / 2, self.h / 2)
            max_dist = center.length()

            for y in range(self.h):
                for x in range(self.w):
                    delta = vec2(x, y) - center
                    delta_dir = delta.normalize()
                    # 0 at the image center, 1 at the corners.
                    distance = delta.length() / max_dist

                    aberration = (pow(distance, self.db.settings.chromatic_aberration_exp) *
                                  self.db.settings.chromatic_aberration_factor)

                    # Each channel is sampled from a slightly different radial offset.
                    nx = x + delta_dir.v[0] * aberration * 0.1
                    ny = y + delta_dir.v[1] * aberration * 0.05
                    out[y*self.w+x].v[2] = self.sample_bilinear(vec2(nx, ny)).v[2]

                    nx = x - delta_dir.v[0] * aberration * 0.05
                    ny = y + delta_dir.v[1] * aberration * 0.1
                    out[y*self.w+x].v[0] = self.sample_bilinear(vec2(nx, ny)).v[0]

                    nx = x + delta_dir.v[0] * aberration * 0.1
                    ny = y - delta_dir.v[1] * aberration * 0.05
                    out[y*self.w+x].v[1] = self.sample_bilinear(vec2(nx, ny)).v[1]

                    out[y*self.w+x] *= (1 - self.db.settings.vignette_strength *
                                            pow(self.db.settings.vignette_distance_scale * distance,
                                                self.db.settings.vignette_distance_exp))

                    out[y*self.w+x] *= 1 - self.db.settings.film_noise_strength * random.uniform(-1, +1)

            smaller = PixelGrid(self.db, int(self.w / 2), int(self.h / 2))

            # Bloom source: keep only pixels bright enough to pass the threshold.
            def threshold(uv: vec2) -> vec3:
                color = self.sample_bilinear(uv * vec2(self.w - 0.5, self.h - 0.5)) * 0.1
                if color.length() < 0.1:
                    color *= 0
                else:
                    color *= 10
                return color

            smaller.blit(threshold)

            blur_tmp = PixelGrid(self.db, smaller.w, smaller.h)

            weights = [1 / 4, 2 / 4, 1 / 4]
            spread = 1.8 * self.w / 320
            offsets = [-spread, 0, +spread]
            rotate = 80 / 180 * math.pi
            kernel = vec2(0, 1)
            sz = vec2(smaller.w - 1, smaller.h - 1)

            for k in range(8):
                # ``src`` and ``kernel`` are looked up when blit() calls blur, so
                # each pass uses the values assigned just before it.
                def blur(uv: vec2) -> vec3:
                    result = vec3(0, 0, 0)
                    for i in range(3):
                        result += src.sample_bilinear(sz * uv + kernel * offsets[i]) * weights[i]
                    return result

                src = smaller
                blur_tmp.blit(blur)
                kernel = kernel.rotate(rotate)

                src = blur_tmp
                smaller.blit(blur)
                kernel = kernel.rotate(rotate)

            for y in range(self.h):
                for x in range(self.w):
                    out[y*self.w+x] += smaller.sample_bilinear(vec2(x - 0.5, y - 0.5) / 2) * 3.3

            plug_w = self.db.settings.plug_width
            plug_h = self.db.settings.plug_height

            if plug_w > 0 and plug_h > 0:
                pg = PixelGrid(self.db, plug_w, plug_h)
                for y in range(plug_h):
                    for x in range(plug_w):
                        intensity = self.db.plug[y*plug_w+x] / 255
                        pg.pixels[y*pg.w+x] = vec3(intensity, intensity, intensity)

                for y in range(self.h):
                    for x in range(self.w):
                        sx = x / (self.w - 1)
                        sx += self.db.settings.plug_border - self.db.settings.plug_begin
                        sx *= 1 / (1 - self.db.settings.plug_begin)

                        sy = 1 - y / (self.h - 1)
                        sy += self.db.settings.plug_border - self.db.settings.plug_begin
                        sy *= 1 / (1 - self.db.settings.plug_begin)

                        intensity = pg.sample_bilinear(vec2(sx * plug_w, (1 - sy) * plug_h)).v[0]
                        out[y*self.w+x] += vec3(1., 1., 1.) * intensity * self.db.settings.plug_blend

        return out

    def to_rgb(self) -> bytearray:
        """Post-process and convert to 8-bit RGB bytes (values clamped to [0, 1], truncated)."""
        buf = bytearray(3 * self.w * self.h)

        postprocessed = self.postprocess()

        for i in range(self.w * self.h):
            for j in range(3):
                buf[i*3+j] = int(0xFF * max(0, min(1, postprocessed[i].v[j])))

        return buf

    def save(self, ppm_filename: str):
        """Write the post-processed image as a binary (P6) PPM file."""
        with open(ppm_filename, 'wb') as fp:
            fp.write(f'P6\n{self.w} {self.h}\n255\n'.encode())
            fp.write(self.to_rgb())

    def rasterize(self, x0: int, y0: int, x1: int, y1: int):
        """Ray trace the pixels in [x0, x1) x [y0, y1) into ``self.pixels``.

        Each pixel averages ``num_samples`` rays from its near-plane point to a
        jittered far-plane point, optionally with depth-of-field jitter.
        """
        settings = self.db.settings
        ndc_to_world = settings.ndc_to_world
        num_samples = settings.num_samples

        topleft = ndc_to_world * self.window_to_ndc(0, 0, 0)
        bottomright = ndc_to_world * self.window_to_ndc(self.w-1, self.h-1, 0)
        screen_size_worldspace = bottomright - topleft
        for y in range(y0, y1):
            for x in range(x0, x1):
                color = vec3(0, 0, 0)

                # The near-plane point does not depend on the sample, so it is
                # unprojected once per pixel instead of once per sample.
                pixel_near = ndc_to_world * self.window_to_ndc(x, y, 0)

                for _ in range(num_samples):
                    aa = pixel_near
                    bb = ndc_to_world * self.window_to_ndc(x + 0.5 * random.uniform(-1, +1),
                                                           y + 0.5 * random.uniform(-1, +1), 1)

                    if settings.dof_strength:
                        fuzz: vec3 = screen_size_worldspace * vec3(random.uniform(-1, +1),
                                                                   random.uniform(-1, +1),
                                                                   random.uniform(-1, +1)) * settings.dof_strength

                        # Builds new vectors; pixel_near is not modified.
                        aa = aa + fuzz
                        bb = bb - fuzz * settings.dof_distance_divider

                    ray = Ray(aa, bb)
                    closest = self.db.query(ray)
                    if closest.intersects:
                        color += self.shader.primary(ray, closest)

                # With zero samples the pixel stays black instead of dividing by zero.
                if num_samples > 0:
                    color /= num_samples

                self.pixels[y*self.w + x] = color

    def get_tasks(self):
        """Yield one picklable task tuple per tile, for ``rasterize_chunk``.

        Tiles are ``tile_size`` squares in row-major order, clipped at the image edges.
        """
        tile = self.db.settings.tile_size
        # Integer ceiling division.
        horizontal_tiles = (self.w + tile - 1) // tile
        vertical_tiles = (self.h + tile - 1) // tile

        for ty in range(vertical_tiles):
            for tx in range(horizontal_tiles):
                x0 = tx * tile
                y0 = ty * tile
                x1 = min(self.w, x0 + tile)
                y1 = min(self.h, y0 + tile)
                yield (self.db.scene_bytes, self.w, self.h, x0, y0, x1, y1)

    def rasterize_parallel(self):
        """Render all tiles in a process pool, unpacking results as they arrive."""
        with multiprocessing.Pool() as pool:
            tasks = list(self.get_tasks())
            total = len(tasks)
            done = 0
            started = time.perf_counter()
            print(end=f'\r{done} / {total} done')
            for result in pool.imap_unordered(rasterize_chunk, tasks):
                self.unpack(result)
                done += 1
                elapsed = time.perf_counter() - started
                print(end=f'\r{done} / {total} done ({elapsed/done:5.1f} secs / tile)')
                sys.stdout.flush()
            print()
            print()

class IntersectionInfo(object):
    """Result of a ray query: whether something was hit and, if so, where and with what.

    A freshly constructed instance describes a miss.
    """

    def __init__(self):
        self.intersects = False
        self.enters = False
        self.point = vec3(0., 0., 0.)
        self.normal = vec3(0., 0., 0.)
        self.material = Material()

    @classmethod
    def create(cls, enters: bool, point: vec3, normal: vec3, material: Material):
        """Build a hit record for ``point`` with surface ``normal`` and ``material``."""
        info = cls()
        info.intersects, info.enters = True, enters
        info.point, info.normal = point, normal
        info.material = material
        return info


class Ray(object):
    """A ray segment from ``origin`` to ``end`` with data cached for AABB tests.

    Attributes:
        origin: start point.
        direction: ``end - origin`` (not normalized; its length bounds SDF marching).
        direction_normalized: unit direction (zero vector for a zero-length ray).
        inverse_direction: per-axis reciprocal of ``direction_normalized``.
        inverse_direction_sign: per-axis 1 if that reciprocal is negative, else 0.
    """

    def __init__(self, origin: vec3, end: vec3):
        self.origin = origin
        self.direction = end - origin
        self.direction_normalized = self.direction.normalize()
        # A zero component must become a signed infinity for the slab test.
        # ``1. / vec3`` substitutes 1 for zero components, which made
        # ``intersect`` reject boxes that an axis-parallel ray passes through.
        self.inverse_direction = vec3(*(
            1. / component if component else math.copysign(math.inf, component)
            for component in self.direction_normalized.v))
        self.inverse_direction_sign = (int(self.inverse_direction.v[0] < 0),
                                       int(self.inverse_direction.v[1] < 0),
                                       int(self.inverse_direction.v[2] < 0))

    def intersect(self, aabb: AABB) -> bool:
        """Conservative slab test against ``aabb``.

        Returns False only when the ray's supporting line misses the box. It
        does not check whether the box lies behind the origin or beyond the end.
        When the origin lies exactly on a slab plane of a zero-direction axis,
        ``0 * inf`` yields NaN. All NaN comparisons below are False, so the
        test then errs towards True.
        """
        invdir = self.inverse_direction
        sign = self.inverse_direction_sign

        bounds = [aabb.min, aabb.max]

        tmin = (bounds[sign[0]].v[0] - self.origin.v[0]) * invdir.v[0]
        tmax = (bounds[1-sign[0]].v[0] - self.origin.v[0]) * invdir.v[0]

        tymin = (bounds[sign[1]].v[1] - self.origin.v[1]) * invdir.v[1]
        tymax = (bounds[1-sign[1]].v[1] - self.origin.v[1]) * invdir.v[1]

        if tmin > tymax or tymin > tmax:
            return False

        if tymin > tmin:
            tmin = tymin

        if tymax < tmax:
            tmax = tymax

        tzmin = (bounds[sign[2]].v[2] - self.origin.v[2]) * invdir.v[2]
        tzmax = (bounds[1-sign[2]].v[2] - self.origin.v[2]) * invdir.v[2]

        if tmin > tzmax or tzmin > tmax:
            return False

        return True

    def is_closer(self, candidate: 'IntersectionInfo', best: 'IntersectionInfo') -> bool:
        """True if ``candidate`` is an entering hit nearer to the origin than ``best``
        (or ``best`` is a miss)."""
        return (candidate.intersects and candidate.enters and
                (not best.intersects or
                 ((candidate.point - self.origin).length2() < (best.point - self.origin).length2())))


def rasterize_chunk(task):
    """Render one tile and return its packed bytes (see PixelGrid.pack).

    ``task`` is a tuple from Framebuffer.get_tasks:
    ``(scene_bytes, w, h, x0, y0, x1, y1)``. Results are cached under
    ``./cache/``, keyed by scene crc32, image size and tile rectangle. A cached
    file, if present, is returned without rendering.
    """
    scene_bytes, w, h, x0, y0, x1, y1 = task

    filename = f'traced.{zlib.crc32(scene_bytes):08x}-{w}x{h}-{x0}-{x1}-{y0}-{y1}.chunk'
    cache_filename = os.path.join('cache', filename)
    if os.path.exists(cache_filename):
        with open(cache_filename, 'rb') as fp:
            return fp.read()

    db = Database(scene_bytes)
    fb = Framebuffer(db, w, h)
    fb.rasterize(x0, y0, x1, y1)

    result_bytes = fb.pack(x0, y0, x1, y1)

    os.makedirs('cache', exist_ok=True)
    # Write a per-process temp file and atomically rename it into place, so an
    # interrupted run never leaves a truncated chunk that later runs would trust.
    tmp_filename = f'{cache_filename}.{os.getpid()}.tmp'
    with open(tmp_filename, 'wb') as fp:
        fp.write(result_bytes)
    os.replace(tmp_filename, cache_filename)

    return result_bytes


def main(filename):
    """Render the scene in *filename* using all CPU cores and write ``<filename>.ppm``."""
    with open(filename, 'rb') as scene_file:
        scene_bytes = scene_file.read()
    db = Database(scene_bytes)

    t0 = time.perf_counter()

    framebuffer = Framebuffer(db)
    framebuffer.rasterize_parallel()
    framebuffer.save(filename + '.ppm')

    elapsed = time.perf_counter() - t0
    print(f'{filename} took {elapsed:.1f} seconds')


def example(filename, task_index: int = 974):
    """Demo/profiling helper: render and unpack a single tile of *filename*.

    Args:
        filename: path of the binary scene file.
        task_index: which tile (index into ``fb.get_tasks()``) to render;
            defaults to the previously hard-coded 974.

    Raises:
        IndexError: if ``task_index`` is out of range for the scene's tile count.
    """
    # Make a Framebuffer object
    with open(filename, 'rb') as fp:
        db = Database(fp.read())

    # width and height are optional parameters
    fb = Framebuffer(db)

    # Query all tasks that need to be done
    tasks = list(fb.get_tasks())

    # How many tasks are there?
    print(len(tasks), 'tasks')

    # Pick just a single task for processing (for demo/testing/profiling purposes)
    if not -len(tasks) <= task_index < len(tasks):
        raise IndexError(f'task_index {task_index} is out of range for {len(tasks)} tasks')
    task = tasks[task_index]

    # For easier integration into multiprocessing, rasterize_chunk() is a global
    # function that takes a task as input and returns a bunch of bytes. As a side
    # effect, it also writes the bytes into "./cache/" and if a cached file is
    # found, it reads and returns the file from the cache.
    task_result = rasterize_chunk(task)

    # The return value of rasterize_chunk() can be easily re-integrated into the
    # Framebuffer object by just passing it to the .unpack() function.
    fb.unpack(task_result)

    # Postprocess and render the final picture (of course, this will only do
    # the right thing if we actually have all the tasks processed and unpacked
    # into the framebuffer, otherwise tiles will be missing)
    #fb.save('out.ppm')


if __name__ == '__main__':
    # Use the single command-line argument as the scene file, else ./scene.dat.
    args = sys.argv[1:]
    filename = args[0] if len(args) == 1 else 'scene.dat'

    # To profile one tile instead of the full render, time example(filename)
    # here in place of main(filename).
    main(filename)
