#include #include #include #include "tree.h" static struct expr_node *alloc_node(void) { return malloc(sizeof(struct expr_node)); } struct expr_node *node_create_const(double val) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_CONST; node->vals.num = val; return node; } struct expr_node *node_create_neg(struct expr_node *unop) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_NEG; node->vals.unop = unop; return node; } static struct expr_node *create_binary_node(enum expr_type type, struct expr_node *left, struct expr_node *right) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = type; node->vals.binop.left = left; node->vals.binop.right = right; return node; } struct expr_node *node_create_add(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_ADD, left, right); } struct expr_node *node_create_sub(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_SUB, left, right); } struct expr_node *node_create_mult(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_MULT, left, right); } struct expr_node *node_create_div(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_DIV, left, right); } struct expr_node *node_create_pow(struct expr_node *base, struct expr_node *power) { return create_binary_node(EXPR_POW, base, power); } struct expr_node *node_create_x(void) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_X; return node; } struct expr_node *node_create_fn(size_t fn_idx, struct expr_node **args) { size_t i, num_args; struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_FN; node->vals.fn.fn_idx = fn_idx; num_args = fns_get()[fn_idx].num_args; for (i = 0; i < num_args; ++i) { node->vals.fn.args[i] = args[i]; } return node; } void node_free(struct expr_node *node) { if (!node) return; switch (node->type) { case EXPR_ADD: case EXPR_SUB: case EXPR_MULT: case EXPR_DIV: case EXPR_POW: node_free(node->vals.binop.left); node_free(node->vals.binop.right); break; case EXPR_NEG: node_free(node->vals.unop); break; case EXPR_FN: { size_t i, num_args = fns_get()[node->vals.fn.fn_idx].num_args; for (i = 0; i < num_args; ++i) { node_free(node->vals.fn.args[i]); } } break; default: break; } free(node); } static void debug_indent(int indent) { int i; for (i = 0; i < indent; ++i) printf(" "); } static void debug_print(struct expr_node *node, int indent); static void debug_print_binop(struct expr_node *node, const char* name, int indent) { debug_indent(indent); printf("[%s]\n", name); debug_print(node->vals.binop.left, indent + 1); debug_print(node->vals.binop.right, indent + 1); } static void debug_print(struct expr_node *node, int indent) { switch (node->type) { case EXPR_ADD: debug_print_binop(node, "ADD", indent); break; case EXPR_SUB: debug_print_binop(node, "SUB", indent); break; case EXPR_MULT: debug_print_binop(node, "MULT", indent); break; case EXPR_DIV: debug_print_binop(node, "DIV", indent); break; case EXPR_POW: debug_print_binop(node, "POW", indent); break; case EXPR_NEG: debug_indent(indent); printf("[NEG]\n"); debug_print(node->vals.unop, indent + 1); break; case EXPR_CONST: debug_indent(indent); printf("[CONST] %.2f\n", node->vals.num); break; case EXPR_X: debug_indent(indent); printf("[X]\n"); break; case EXPR_FN: { size_t i; const struct math_function *fn = &fns_get()[node->vals.fn.fn_idx]; debug_indent(indent); printf("[FN] %s\n", fn->name); for (i = 0; i < fn->num_args; ++i) { debug_print(node->vals.fn.args[i], indent + 1); } break; } default: break; } } void node_debug_print(struct expr_node *node) { debug_print(node, 0); } static void debug_print_gv(const struct expr_node *node, FILE *output); static void debug_print_binop_gv(const struct expr_node *node, FILE *output, const char *name) { fprintf(output, "node%p [label=\"%s\"]\n", (void*)node, name); debug_print_gv(node->vals.binop.left, output); debug_print_gv(node->vals.binop.right, output); fprintf(output, "node%p -> node%p [label=left]\n", (void*)node, (void*)node->vals.binop.left); fprintf(output, "node%p -> node%p [label=right]\n", (void*)node, (void*)node->vals.binop.right); } static void debug_print_gv(const struct expr_node *node, FILE *output) { switch (node->type) { case EXPR_ADD: debug_print_binop_gv(node, output, "ADD"); break; case EXPR_SUB: debug_print_binop_gv(node, output, "SUB"); break; case EXPR_MULT: debug_print_binop_gv(node, output, "MULT"); break; case EXPR_DIV: debug_print_binop_gv(node, output, "DIV"); break; case EXPR_POW: debug_print_binop_gv(node, output, "POW"); break; case EXPR_NEG: fprintf(output, "node%p [label=\"NEG\"]\n", (void*)node); debug_print_gv(node->vals.unop, output); fprintf(output, "node%p -> node%p [label=unop]\n", (void*)node, (void*)node->vals.unop); break; case EXPR_CONST: fprintf(output, "node%p [label=\"CONST: %.2f\"]\n", (void*)node, node->vals.num); break; case EXPR_X: fprintf(output, "node%p [label=\"X\"]\n", (void*)node); break; case EXPR_FN: { size_t i; const struct math_function *fn = &fns_get()[node->vals.fn.fn_idx]; fprintf(output, "node%p [label=\"FN: %s\"]\n", (void*)node, fn->name); for (i = 0; i < fn->num_args; ++i) { struct expr_node *arg = node->vals.fn.args[i]; debug_print_gv(arg, output); fprintf(output, "node%p -> node%p [label=arg%d]\n", (void*)node, (void*)arg, (int)i + 1); } break; } default: break; } } void node_debug_print_gv(const struct expr_node *node, FILE *output) { fprintf(output, "digraph G {\n"); debug_print_gv(node, output); fprintf(output, "}\n"); } double node_eval(const struct expr_node *node, double x) { switch (node->type) { case EXPR_CONST: return node->vals.num; case EXPR_X: return x; case EXPR_NEG: return -node_eval(node->vals.unop, x); case EXPR_ADD: return node_eval(node->vals.binop.left, x) + node_eval(node->vals.binop.right, x); case EXPR_SUB: return node_eval(node->vals.binop.left, x) - node_eval(node->vals.binop.right, x); case EXPR_MULT: return node_eval(node->vals.binop.left, x) * node_eval(node->vals.binop.right, x); case EXPR_DIV: return node_eval(node->vals.binop.left, x) / node_eval(node->vals.binop.right, x); case EXPR_POW: return pow(node_eval(node->vals.binop.left, x), node_eval(node->vals.binop.right, x)); case EXPR_FN: { double inner_results[MAX_MATH_FUNCTION_ARGS]; size_t i; const struct math_function *fn = &fns_get()[node->vals.fn.fn_idx]; for (i = 0; i < fn->num_args; ++i) { inner_results[i] = node_eval(node->vals.fn.args[i], x); } return fn->ptr(inner_results); } } return 0.0; }