#include #include #include #include "tree.h" static struct expr_node *alloc_node(void) { return malloc(sizeof(struct expr_node)); } struct expr_node *node_create_const(double val) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_CONST; node->vals.num = val; return node; } struct expr_node *node_create_neg(struct expr_node *unop) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_NEG; node->vals.unop = unop; return node; } static struct expr_node *create_binary_node(enum expr_type type, struct expr_node *left, struct expr_node *right) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = type; node->vals.binop.left = left; node->vals.binop.right = right; return node; } struct expr_node *node_create_add(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_ADD, left, right); } struct expr_node *node_create_sub(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_SUB, left, right); } struct expr_node *node_create_mult(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_MULT, left, right); } struct expr_node *node_create_div(struct expr_node *left, struct expr_node *right) { return create_binary_node(EXPR_DIV, left, right); } struct expr_node *node_create_pow(struct expr_node *base, struct expr_node *power) { return create_binary_node(EXPR_POW, base, power); } struct expr_node *node_create_x(void) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_X; return node; } struct expr_node *node_create_fn(enum math_fn fn, struct expr_node *arg) { struct expr_node *node = alloc_node(); if (!node) return NULL; node->type = EXPR_FN; node->vals.fn.fn = fn; node->vals.fn.arg = arg; return node; } void node_free(struct expr_node *node) { if (!node) return; switch (node->type) { case EXPR_ADD: case EXPR_SUB: case EXPR_MULT: case EXPR_DIV: case EXPR_POW: node_free(node->vals.binop.left); node_free(node->vals.binop.right); break; case EXPR_NEG: node_free(node->vals.unop); break; case EXPR_FN: node_free(node->vals.fn.arg); break; default: break; } free(node); } static void debug_indent(int indent) { int i; for (i = 0; i < indent; ++i) printf(" "); } static void debug_print(struct expr_node *node, int indent); static void debug_print_binop(struct expr_node *node, const char* name, int indent) { debug_indent(indent); printf("[%s]\n", name); /*debug_indent(indent); printf("left:\n");*/ debug_print(node->vals.binop.left, indent + 1); /*debug_indent(indent); printf("right:\n");*/ debug_print(node->vals.binop.right, indent + 1); } static void debug_print(struct expr_node *node, int indent) { static const char* fn_str[] = { "FN_ABS", "FN_EXP", "FN_LN", "FN_LOG", "FN_SIN", "FN_COS", "FN_TAN", "FN_ASIN", "FN_ACOS", "FN_ATAN", "FN_SINH", "FN_COSH", "FN_TANH" }; switch (node->type) { case EXPR_ADD: debug_print_binop(node, "ADD", indent); break; case EXPR_SUB: debug_print_binop(node, "SUB", indent); break; case EXPR_MULT: debug_print_binop(node, "MULT", indent); break; case EXPR_DIV: debug_print_binop(node, "DIV", indent); break; case EXPR_POW: debug_print_binop(node, "POW", indent); break; case EXPR_NEG: debug_indent(indent); printf("[NEG]\n"); /*debug_indent(indent); printf("unop:\n");*/ debug_print(node->vals.unop, indent + 1); break; case EXPR_CONST: debug_indent(indent); printf("[CONST] %.2f\n", node->vals.num); break; case EXPR_X: debug_indent(indent); printf("[X]\n"); break; case EXPR_FN: debug_indent(indent); printf("[FN] %s\n", fn_str[node->vals.fn.fn]); /*debug_indent(indent); printf("arg:\n");*/ debug_print(node->vals.fn.arg, indent + 1); break; default: break; } } void node_debug_print(struct expr_node *node) { debug_print(node, 0); } double node_eval(struct expr_node *node, double x) { switch (node->type) { case EXPR_CONST: return node->vals.num; case EXPR_X: return x; case EXPR_NEG: return -node_eval(node->vals.unop, x); case EXPR_ADD: return node_eval(node->vals.binop.left, x) + node_eval(node->vals.binop.right, x); case EXPR_SUB: return node_eval(node->vals.binop.left, x) - node_eval(node->vals.binop.right, x); case EXPR_MULT: return node_eval(node->vals.binop.left, x) * node_eval(node->vals.binop.right, x); case EXPR_DIV: return node_eval(node->vals.binop.left, x) / node_eval(node->vals.binop.right, x); case EXPR_POW: return pow(node_eval(node->vals.binop.left, x), node_eval(node->vals.binop.right, x)); case EXPR_FN: { double inner = node_eval(node->vals.fn.arg, x); switch (node->vals.fn.fn) { case FN_ABS: return fabs(inner); case FN_EXP: return exp(inner); case FN_LN: return log(inner); case FN_LOG: return log10(inner); case FN_SIN: return sin(inner); case FN_COS: return cos(inner); case FN_TAN: return tan(inner); case FN_ASIN: return asin(inner); case FN_ACOS: return acos(inner); case FN_ATAN: return atan(inner); case FN_SINH: return sinh(inner); case FN_COSH: return cosh(inner); case FN_TANH: return tanh(inner); } } } return 0.0; }