
/*
 * blktool
 *
 * Copyright 2004 Jeff Garzik
 *
 * This software may be used and distributed according to the terms
 * of the GNU General Public License, incorporated herein by reference.
 *
 */

#include "blktool-config.h"

#include <sys/types.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <getopt.h>
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <sys/sysmacros.h>

#include <linux/hdreg.h>
#include <linux/fs.h>
#include <linux/cdrom.h>
#include <linux/major.h>
#include <scsi/scsi.h>

#include "blktool.h"

/*
 * I2O block cache-strategy ioctls: get/set the read strategy (…RSTRAT)
 * and the write strategy (…WSTRAT).  Used by the "i2o-rcache" and
 * "i2o-wcache" entries in int_cmd_tbl.  Defined locally, presumably
 * because no userspace header exports them -- confirm against the
 * kernel's i2o headers before changing.
 */
#define         BLKI2OGRSTRAT   _IOR('2', 1, int)
#define         BLKI2OGWSTRAT   _IOR('2', 2, int)
#define         BLKI2OSRSTRAT   _IOW('2', 3, int)
#define         BLKI2OSWSTRAT   _IOW('2', 4, int)

/*
 * Prototypes for the command handlers defined later in this file.
 * DEF_HANDLER_PROTO is not defined here; presumably it comes from
 * blktool.h and declares handle_<name>(int, char **, struct command *).
 */
static DEF_HANDLER_PROTO(bool);
static DEF_HANDLER_PROTO(class);
static DEF_HANDLER_PROTO(dev_read_ahead);
static DEF_HANDLER_PROTO(geometry);
static DEF_HANDLER_PROTO(id);
static DEF_HANDLER_PROTO(int);
static DEF_HANDLER_PROTO(media);
static DEF_HANDLER_PROTO(reread_part);
static DEF_HANDLER_PROTO(sector_sz);
static DEF_HANDLER_PROTO(standby);
static DEF_HANDLER_PROTO(version);
static DEF_HANDLER_PROTO(wcache);

/*
 * Fallback operations table.  As a static object it is zero-filled, so
 * every hook is NULL and the ops-based handlers (id, standby, media,
 * dev-read-ahead, wcache) report MSG_UNKNOWN_ID unless ops is repointed
 * -- presumably by ata_init()/scsi_init(); confirm in those modules.
 */
static struct class_operations dummy_ops;
/* Device class: chosen with -t, otherwise detected in open_blkdev(). */
static dev_class_t dev_class = dc_unknown;
/* Set by -f; in this file only handle_wcache() consults it. */
static int opt_force;

/* Descriptor of the opened device; -1 until main() calls open_blkdev(). */
int blkdev = -1;
/* Not referenced in this file; presumably used by other modules. */
GPtrArray *flag_arr;
/* Class-specific operations in effect; starts out as dummy_ops. */
struct class_operations *ops = &dummy_ops;

/*
 * Display names for the device classes, indexed by dev_class_t value.
 * Used by handle_class() and the usage_*() helpers.  Both the pointers
 * and the strings are read-only, so the table lives in read-only data.
 */
static const char *const dev_class_names[] = {
	"any",
	"ATA",
	"SCSI",
	"I2O",
};

/*
 * Helpers for building the command tables below:
 *   IOCNAME(tok)   -> the ioctl value followed by its name as a string,
 *                     filling an (ioctl, ioctl_name) pair.
 *   DEF_HDIO(tok)  -> IOCNAME(HDIO_GET_<tok>), IOCNAME(HDIO_SET_<tok>),
 *                     i.e. the read pair then the write pair.
 *   DEF_BOOLSTR    -> the default value strings "off", "on".
 *   DEF_BOOL(str)  -> command name, ct_bool, and the generic handle_bool.
 */
#define IOCNAME(token) token, #token
#define DEF_HDIO(token) IOCNAME(HDIO_GET_##token), IOCNAME(HDIO_SET_##token)
#define DEF_BOOLSTR "off", "on"
#define DEF_BOOL(str) str, ct_bool, handle_bool

/*
 * On/off commands.  Each entry is an embedded struct command
 *   { name, type, handler, required class,
 *     read ioctl, read ioctl name, write ioctl, write ioctl name }
 * followed by the two value strings shown by usage() and handle_bool()
 * (the false string first, e.g. "off"/"on", "unlock"/"lock").
 * A zero ioctl with a NULL name means "not available"; entries that
 * name a dedicated handler do their own work instead of handle_bool().
 * "readonly" also selects bc_arg_int_ptr, so handle_bool() passes a
 * pointer to the value to its write ioctl rather than the value itself.
 */
static struct bool_command bool_cmd_tbl[] = {
	{ { "defect-mgmt", ct_bool, handle_defect_mgmt, dc_ata,
	    0, NULL, IOCNAME(HDIO_DRIVE_CMD) }, DEF_BOOLSTR },
	{ { "dev-keep-settings", ct_bool, handle_dev_keep_settings, dc_ata,
	    0, NULL, 0, NULL }, DEF_BOOLSTR },
	{ { "dev-read-ahead", ct_bool, handle_dev_read_ahead, dc_any,
	    0, NULL, IOCNAME(HDIO_DRIVE_CMD) }, DEF_BOOLSTR },
	{ { DEF_BOOL("dma"), dc_ata, DEF_HDIO(DMA) }, DEF_BOOLSTR },
	{ { DEF_BOOL("keep-settings"), dc_ata, DEF_HDIO(KEEPSETTINGS) },
	  DEF_BOOLSTR },
	{ { "media", ct_bool, handle_media, dc_any,
	    0, NULL, 0, NULL }, "unlock", "lock" },
	{ { DEF_BOOL("no-write-err"), dc_ata, DEF_HDIO(NOWERR) },
	  DEF_BOOLSTR },
	{ { DEF_BOOL("pio-data"), dc_ata, DEF_HDIO(32BIT) },
	  "16-bit", "32-bit" },
	{ { DEF_BOOL("readonly"), dc_any, IOCNAME(BLKROGET), IOCNAME(BLKROSET) },
	  DEF_BOOLSTR, bc_arg_int_ptr },
	{ { DEF_BOOL("unmask-irq"), dc_ata, DEF_HDIO(UNMASKINTR) },
	  DEF_BOOLSTR },
	{ { "wcache", ct_bool, handle_wcache, dc_any,
	    0, NULL, 0, NULL }, DEF_BOOLSTR },
};

/* Command name plus ct_int type and the generic handle_int handler. */
#define DEF_INT(str) str, ct_int, handle_int

/*
 * Integer-valued commands.  Each entry is a struct command
 *   { name, type, handler, required class,
 *     read ioctl, read ioctl name, write ioctl, write ioctl name }.
 * A zero ioctl with a NULL name means that direction is unavailable
 * (e.g. "cd-speed" and "pio-mode" can only be set).  "pm-mode" names
 * its own handler instead of handle_int.
 */
static struct command int_cmd_tbl[] = {
	{ DEF_INT("acoustic-mgmt"), dc_ata, DEF_HDIO(ACOUSTIC) },
	{ DEF_INT("block-sz"), dc_any, IOCNAME(BLKBSZGET), IOCNAME(BLKBSZSET) },
	{ DEF_INT("bus-state"), dc_ata, DEF_HDIO(BUSSTATE) },
	{ DEF_INT("cd-speed"), dc_any, 0, NULL, IOCNAME(CDROM_SELECT_SPEED) },
	{ DEF_INT("i2o-rcache"), dc_i2o, IOCNAME(BLKI2OGRSTRAT), IOCNAME(BLKI2OSRSTRAT) },
	{ DEF_INT("i2o-wcache"), dc_i2o, IOCNAME(BLKI2OGWSTRAT), IOCNAME(BLKI2OSWSTRAT) },
	{ DEF_INT("multiple-count"), dc_ata, DEF_HDIO(MULTCOUNT) },
	{ DEF_INT("pio-mode"), dc_ata, 0, NULL, IOCNAME(HDIO_SET_PIO_MODE) },
	{ "pm-mode", ct_int, handle_pm_mode, dc_ata, 0, NULL, IOCNAME(HDIO_DRIVE_CMD) },
	{ DEF_INT("queue-depth"), dc_ata, DEF_HDIO(QDMA) },
	{ DEF_INT("read-ahead"), dc_any, IOCNAME(BLKRAGET), IOCNAME(BLKRASET) },
};

/* Command name plus ct_void type and the handle_<handler> function. */
#define DEF_VOID(str, handler) \
	str, ct_void, handle_##handler

/*
 * Commands that take no argument.  Only the name, type, handler and
 * required class are given; the ioctl fields are left zero/NULL.
 * handle_bus_id, handle_reset and handle_sleep are not defined in this
 * file (presumably in the class-specific modules).
 */
static struct command void_cmd_tbl[] = {
	{ DEF_VOID("bus-id", bus_id), dc_scsi },
	{ DEF_VOID("class", class), dc_any },
	{ DEF_VOID("geometry", geometry), dc_any },
	{ DEF_VOID("id", id), dc_any },
	{ DEF_VOID("reread-part", reread_part), dc_any },
	{ DEF_VOID("reset", reset), dc_ata },
	{ DEF_VOID("sector-sz", sector_sz), dc_any },
	{ DEF_VOID("sleep", sleep), dc_ata },
	{ DEF_VOID("standby", standby), dc_any },
	{ DEF_VOID("version", version), dc_any },
};


/*
 * Generic handler for integer-valued commands (int_cmd_tbl).
 *
 * With no argument after COMMAND, issue cmd->read_ioctl with a pointer to
 * an int and print the result.  With exactly one argument, parse it as a
 * decimal int and issue cmd->write_ioctl.  Any other argument count, a
 * missing ioctl for the requested direction, or an argument that is not a
 * valid int is reported on stderr and exits with status 1.
 *
 * Argument counts are measured from optind (main() leaves it pointing
 * just past COMMAND), so -f / -t options do not shift them.
 */
static void handle_int(int argc, char **argv, struct command *cmd)
{
	int nargs = argc - optind;
	int do_32;
	int ret;

	if (nargs == 0) {
		if (cmd->read_ioctl_name == NULL) {
			fprintf(stderr, MSG_INVALID_CMD_NOREAD, cmd->cmd);
			exit(1);
		}

		do_32 = 0;
		if (ioctl(blkdev, cmd->read_ioctl, &do_32))
			pdie(cmd->read_ioctl_name, 1);

		printf("%d\n", do_32);

	} else if ((nargs == 1) && (cmd->write_ioctl_name != NULL)) {
		const char *arg = argv[optind];
		char *end;
		long val;

		/* strtol, not atoi: reject garbage instead of silently writing 0 */
		errno = 0;
		val = strtol(arg, &end, 10);
		if ((end == arg) || (*end != '\0') || (errno == ERANGE) ||
		    (val < INT_MIN) || (val > INT_MAX)) {
			fprintf(stderr, MSG_INVALID_CMD_ARGS, cmd->cmd);
			exit(1);
		}
		do_32 = (int) val;

		/*
		 * BLKBSZSET reads its int through a user pointer; the other
		 * write ioctls in int_cmd_tbl are handed the value directly.
		 */
		if (cmd->write_ioctl == BLKBSZSET)
			ret = ioctl(blkdev, cmd->write_ioctl, &do_32);
		else
			ret = ioctl(blkdev, cmd->write_ioctl, do_32);
		if (ret)
			pdie(cmd->write_ioctl_name, 1);
	}
	else {
		fprintf(stderr, MSG_INVALID_CMD_ARGS, cmd->cmd);
		exit(1);
	}
}

static void handle_bool(int argc, char **argv, struct command *cmd)
{
	struct bool_command *bcm = (struct bool_command *) cmd;
	int do_32;

	if (argc == 3) {
		if (cmd->read_ioctl_name == NULL) {
			fprintf(stderr, MSG_INVALID_CMD_NOREAD, cmd->cmd);
			exit(1);
		}

		do_32 = 0;
		if (ioctl(blkdev, cmd->read_ioctl, &do_32))
			pdie(cmd->read_ioctl_name, 1);

		printf("%s\n", do_32 ? bcm->str_true : bcm->str_false);

	} else if ((argc == 4) && (cmd->write_ioctl_name != NULL)) {
		do_32 = parse_bool(argc, argv, bcm);
		
		int ret;
		if (bcm->arg_type == bc_arg_int_ptr) {
			ret = ioctl(blkdev, cmd->write_ioctl, &do_32);
		} else {
			ret = ioctl(blkdev, cmd->write_ioctl, do_32);
		}
		if (ret)
			pdie(cmd->write_ioctl_name, 1);
	}
	else {
		fprintf(stderr, MSG_INVALID_CMD_ARGS, cmd->cmd);
		exit(1);
	}
}

/* Print the version banner (MSG_VERSION / VERSION) on stdout. */
static void handle_version(int argc, char **argv, struct command *cmd)
{
	/* the handler signature is shared; none of the arguments matter here */
	(void) argc;
	(void) argv;
	(void) cmd;

	printf(MSG_VERSION, VERSION);
}

static void handle_geometry(int argc, char **argv, struct command *cmd)
{
	struct hd_geometry g;

	IOCTL(HDIO_GETGEO, &g);

	printf("head %u sect %u cyl %u start %lu\n",
	       g.heads,
	       g.sectors,
	       g.cylinders,
	       g.start);
}

/*
 * Run a read-only ioctl that stores an int through its argument, then
 * print that int.  On failure the error is passed to pdie() under the
 * given ioctl name.  argc/argv are accepted for symmetry and unused.
 */
static void generic_ro_int(int argc, char **argv, int cmd, const char *name)
{
	int result = 0;

	if (ioctl(blkdev, cmd, &result) != 0)
		pdie(name, 1);
	printf("%d\n", result);
}

/* Print the sector size reported by the BLKSSZGET ioctl. */
static void handle_sector_sz(int argc, char **argv, struct command *cmd)
{
	static const char ioc_name[] = "BLKSSZGET";

	generic_ro_int(argc, argv, BLKSSZGET, ioc_name);
}

/*
 * "reread-part": issue BLKRRPART (re-read the partition table) on the
 * open device.  Error handling is whatever the IOCTL() helper does --
 * presumably report and exit; confirm in blktool.h.
 */
static void handle_reread_part(int argc, char **argv, struct command *cmd)
{
	IOCTL(BLKRRPART, NULL);
}

/*
 * "id": delegate to the class-specific ops->id hook.  If the current
 * class provides none, print MSG_UNKNOWN_ID and exit with status 1.
 */
static void handle_id(int argc, char **argv, struct command *cmd)
{
	if (!ops->id) {
		fprintf(stderr, MSG_UNKNOWN_ID);
		exit(1);
	}

	ops->id();
}

/* "class": print the display name of the current device class. */
static void handle_class(int argc, char **argv, struct command *cmd)
{
	const char *name = dev_class_names[dev_class];

	/* puts() appends the newline, matching printf("%s\n") */
	puts(name);
}

/*
 * "standby": delegate to the class-specific ops->standby hook.  If the
 * current class provides none, print MSG_UNKNOWN_ID and exit with 1.
 */
static void handle_standby(int argc, char **argv, struct command *cmd)
{
	if (ops->standby == NULL) {
		fprintf(stderr, MSG_UNKNOWN_ID);
		exit(1);
	}

	ops->standby();
}

/*
 * "media": parse the lock/unlock argument with get_bool(), then pass it
 * to ops->media.  Without a media hook, print MSG_UNKNOWN_ID and exit 1.
 * The argument is parsed before the hook is checked.
 */
static void handle_media(int argc, char **argv, struct command *cmd)
{
	int lock = get_bool(argc, argv, cmd);

	if (ops->media == NULL) {
		fprintf(stderr, MSG_UNKNOWN_ID);
		exit(1);
	}

	ops->media(lock);
}

/*
 * "dev-read-ahead": parse the on/off argument with get_bool(), then pass
 * it to ops->read_ahead.  Without that hook, print MSG_UNKNOWN_ID and
 * exit 1.  The argument is parsed before the hook is checked.
 */
static void handle_dev_read_ahead(int argc, char **argv, struct command *cmd)
{
	int enable = get_bool(argc, argv, cmd);

	if (ops->read_ahead == NULL) {
		fprintf(stderr, MSG_UNKNOWN_ID);
		exit(1);
	}

	ops->read_ahead(enable);
}

/*
 * "wcache": parse the on/off argument with get_bool() and pass it to
 * ops->wcache.  Unless -f was given, BLKFLSBUF is issued first to flush
 * the block device buffers.  Without a wcache hook, print
 * MSG_UNKNOWN_ID and exit 1.  The argument is parsed before the hook is
 * checked.
 */
static void handle_wcache(int argc, char **argv, struct command *cmd)
{
	int enable = get_bool(argc, argv, cmd);

	if (ops->wcache == NULL) {
		fprintf(stderr, MSG_UNKNOWN_ID);
		exit(1);
	}

	if (opt_force == 0) {
		/* flush blkdev buffers */
		IOCTL(BLKFLSBUF, NULL);
	}

	ops->wcache(enable);
}

/*
 * Guess the device class from the major number of a block device node.
 * IDE/ATA majors (IDE0..IDE9, IDE tape) select dc_ata, and SCSI majors
 * select dc_scsi.  Any other major leaves dev_class untouched.  Uses the
 * GNU C "case low ... high" range extension.
 */
static void detect_dev_class(dev_t st_rdev)
{
	switch (major(st_rdev)) {
	case IDE0_MAJOR:
	case IDE1_MAJOR:
	case IDE2_MAJOR:
	case IDE3_MAJOR:
	case IDE4_MAJOR:
	case IDE5_MAJOR:
	case IDE6_MAJOR:
	case IDE7_MAJOR:
	case IDE8_MAJOR:
	case IDE9_MAJOR:
	case IDETAPE_MAJOR:
		dev_class = dc_ata;
		break;

	case SCSI_DISK0_MAJOR:
	case SCSI_TAPE_MAJOR:
	case SCSI_CDROM_MAJOR:
	case SCSI_GENERIC_MAJOR:
	case SCSI_DISK1_MAJOR ... SCSI_DISK7_MAJOR:
	case SCSI_DISK8_MAJOR ... SCSI_DISK15_MAJOR:
		dev_class = dc_scsi;
		break;

	default:
		/* leave default dev_class value as-is */
		break;
	}
}

/*
 * Open fn read-only and non-blocking, and confirm it is a block device.
 * If no class has been chosen yet (dev_class == dc_unknown), the class is
 * inferred from the device number.  Open/fstat failures go to pdie().  A
 * non-block file prints MSG_NOT_BLKDEV and exits with 1.  Returns the
 * open descriptor.
 */
static int open_blkdev(const char *fn)
{
	struct stat sb;
	int fd = open(fn, O_RDONLY | O_NONBLOCK);

	if (fd < 0)
		pdie(fn, 1);
	if (fstat(fd, &sb) < 0)
		pdie(fn, 1);

	if (S_ISBLK(sb.st_mode)) {
		if (dev_class == dc_unknown)
			detect_dev_class(sb.st_rdev);
		return fd;
	}

	fprintf(stderr, MSG_NOT_BLKDEV, fn);
	close(fd);
	exit(1);
}

static void usage_bool(struct bool_command *bcmd)
{
	fprintf(stderr, "%-8s%s { %s | %s }\n",
		dev_class_names[bcmd->cmd.class_required],
		bcmd->cmd.cmd,
		bcmd->str_false,
		bcmd->str_true);
}

/* Print one usage line for an argument-less command: class and name. */
static void usage_void(struct command *cmd)
{
	const char *cls = dev_class_names[cmd->class_required];

	fprintf(stderr, "%-8s%s\n", cls, cmd->cmd);
}

/* Print one usage line for an integer command: class, name, "nnn". */
static void usage_int(struct command *cmd)
{
	const char *cls = dev_class_names[cmd->class_required];

	fprintf(stderr, "%-8s%s nnn\n", cls, cmd->cmd);
}

/*
 * Print the version banner, the command-line synopsis and every known
 * command (bool, then int, then void tables) to stderr, then exit(1).
 */
static void usage(const char *progname)
{
	size_t n;

	fprintf(stderr, MSG_VERSION, VERSION);
	fprintf(stderr, "usage: %s [options] DEVICE COMMAND [args...]\n", progname);
	fputs("command list:\n", stderr);

	for (n = 0; n < ARRAY_SIZE(bool_cmd_tbl); n++)
		usage_bool(bool_cmd_tbl + n);

	for (n = 0; n < ARRAY_SIZE(int_cmd_tbl); n++)
		usage_int(int_cmd_tbl + n);

	for (n = 0; n < ARRAY_SIZE(void_cmd_tbl); n++)
		usage_void(void_cmd_tbl + n);

	exit(1);
}

/*
 * blktool [-f] [-t ide|ata|scsi|i2o|auto] DEVICE COMMAND [args...]
 * blktool version
 *
 * Parse options, open DEVICE, run the class-specific init, then look
 * COMMAND up in the bool, int and void command tables and run its
 * handler.  Handlers read their own arguments starting at argv[optind].
 * Unknown commands or bad usage print the usage text and exit(1).
 */
int main (int argc, char *argv[])
{
	size_t i;
	int ch;
	struct command *cmd;
	const char *cmd_str, *blkdev_name, *prog_name = argv[0];

	while ((ch = getopt(argc, argv, "ft:")) != -1) {
		switch (ch) {
		case 'f':
			opt_force = 1;
			break;

		case 't':
			/* "ide" is accepted as an alias for "ata" */
			if (!strcmp(optarg, "ide") || !strcmp(optarg, "ata"))
				dev_class = dc_ata;
			else if (!strcmp(optarg, "scsi"))
				dev_class = dc_scsi;
			else if (!strcmp(optarg, "i2o"))
				dev_class = dc_i2o;
			else if (!strcmp(optarg, "auto"))
				dev_class = dc_unknown;
			else
				usage(prog_name);
			break;

		default:
			usage(prog_name);
			break;
		}
	}

	/*
	 * "version" needs no device, so accept it on its own (in place of
	 * DEVICE) before insisting on the DEVICE COMMAND pair.
	 */
	if (optind >= argc)
		usage(prog_name);
	if (!strcmp(argv[optind], "version")) {
		handle_version(argc, argv, NULL);
		exit(0);
	}

	if ((optind + 1) >= argc)
		usage(prog_name);
	if (!strcmp(argv[optind + 1], "version")) {
		handle_version(argc, argv, NULL);
		exit(0);
	}

	blkdev_name = argv[optind];
	blkdev = open_blkdev(blkdev_name);
	optind++;

	/* after this, argv[optind] is the first argument to COMMAND */
	cmd_str = argv[optind];
	optind++;

	switch(dev_class) {
	case dc_ata:
		ata_init();
		break;
	case dc_scsi:
		scsi_init();
		break;
	default:
		break;
	}

	for (i = 0; i < ARRAY_SIZE(bool_cmd_tbl); i++) {
		cmd = &bool_cmd_tbl[i].cmd;
		if (!strcmp(cmd->cmd, cmd_str)) {
			cmd->handler(argc, argv, cmd);
			goto out;
		}
	}
	for (i = 0; i < ARRAY_SIZE(int_cmd_tbl); i++) {
		cmd = &int_cmd_tbl[i];
		if (!strcmp(cmd->cmd, cmd_str)) {
			cmd->handler(argc, argv, cmd);
			goto out;
		}
	}
	for (i = 0; i < ARRAY_SIZE(void_cmd_tbl); i++) {
		cmd = &void_cmd_tbl[i];
		if (!strcmp(cmd->cmd, cmd_str)) {
			cmd->handler(argc, argv, cmd);
			goto out;
		}
	}

	usage(prog_name);

out:
	close(blkdev);

	return 0;
}
