/*
 * keyconv.c - Convert OpenSSL DER encoded RSA keys to alternative format
 * 
 * Copyright (C) 2004  Jochen Eisinger <jochen@penguin-breeder.org>
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */
#include <stdlib.h>
#include <stdio.h>

unsigned char 	OID[] =	"\x30\x0d\x06\x09\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01\x05\x00";
#define OID_LEN	15

/* */

/* primitive ASN1 DER parser.. enough for what we need */
int asn1_parse(unsigned char **in, unsigned char **data, int *type, int *len)
{
	int tmp = **in, ctr=2;

	/* we only can parse universal class */
	if (tmp & (0x80 + 0x40))
		return -1;

	/* we don't care whether it's primitive or constructed, as long as it's 
	 * definitive length
	 */
	*type = (tmp & 0x1f);

	/* we can't handle hight-tag numbers */
	if (*type == 0x1f)
		return -1;

	tmp = (*in)[1];
	*len = 0;
	if (tmp & 0x80) {
		tmp &= 0x7f;

		/* we only handle definitive length */
		if (tmp == 0)
			return -1;

		while (tmp-- > 0) {
			*len <<= 8;
			*len += (*in)[ctr++];
		}
	} else {
		*len = tmp;
	}
	*in += ctr;
	*data = *in;
	*in += *len;
	return 0;

}

unsigned char *asn1_enc(unsigned char *data, int type,int *len)
{
	unsigned char *tmp;
	int nlen = *len + 2,ctr=0;

	if (*len > 0x7f) {
		nlen = *len;
		while (nlen > 0) {
			ctr++;
			nlen >>= 8;
		}
		nlen = 2 + *len + ctr;
	}
	if ((tmp = malloc(nlen)) == NULL)
		return NULL;

	memcpy(tmp+2+ctr,data,*len);
	tmp[0] = type;
	if (*len < 0x7f)
		tmp[1] = *len;
	else {
		tmp[1] = 0x80 + ctr;
		
		while (ctr > 0) {
			tmp[1+ctr] = (*len & 0xff);
			*len >>= 8;
			ctr--;
		}
	}
	*len = nlen;
	return tmp;
}


/* invalid data may result in segmentation faults... */
int main(int argc, char **argv)
{
	FILE *fp;
	unsigned char *indata, *outdata, *keydata[9];
	long fsize, cur=0;
	int len, type, keylen[9], tmp;

	if ((fp = fopen(argv[1], "r")) == NULL) {
		fprintf(stderr, "could not open input file \"%s\"\n", argv[1]);
		return 1;
	}

	fseek(fp,0,SEEK_END);
	fsize = ftell(fp);
	fseek(fp,0,SEEK_SET);

	if ((indata = malloc(fsize)) == NULL) {
		fprintf(stderr, "not enough free memory\n");
		return 1;
	}

	while (cur < fsize)
		cur += fread(&indata[cur],1,fsize-cur,fp);

	fclose(fp);
		
	if (asn1_parse(&indata,&outdata,&type,&len)) {
		fprintf(stderr,"error parsing ASN.1 structure\n");
		return 1;
	}

	if (type != 16) {
		fprintf(stderr,"expected SEQUENCE, read %d\n", type);
		return 1;
	}

	indata = outdata;
	/* indata should now be 9 consecutive integers */
	for (cur=0 ; cur<9 ; cur++) {

		if (asn1_parse(&indata,&outdata,&type,&len)) {
			fprintf(stderr,"error parsing ASN.1 structure\n");
			return 1;
		}


		if (type != 2) {
			fprintf(stderr,"expect INTEGER, read %d\n", type);
			return 1;
		}

		tmp = len;
		keydata[cur] = asn1_enc(outdata,0x02,&tmp);

		if (keydata[cur]  == NULL) {
			fprintf(stderr,"not enough free memory\n");
			return 1;
		}
		keylen[cur] = tmp;
	}

	/* now construct the new key format */
	len = keylen[1] + keylen[2] + keylen[3] + keylen[4] + keylen[5] + keylen[8];
	
	if ((outdata = malloc(len)) == NULL) {
		fprintf(stderr,"not enough free memory\n");
		return 1;
	}

	memcpy(outdata,keydata[1],keylen[1]);
	memcpy(outdata+keylen[1],keydata[2],keylen[2]);
	memcpy(outdata+keylen[1]+keylen[2],keydata[3],keylen[3]);
	memcpy(outdata+keylen[1]+keylen[2]+keylen[3],keydata[4],keylen[4]);
	memcpy(outdata+keylen[1]+keylen[2]+keylen[3]+keylen[4],keydata[5],keylen[5]);
	memcpy(outdata+len-keylen[8],keydata[8],keylen[8]);

	indata = asn1_enc(outdata, 0x30, &len);

	if (indata == NULL) {
		fprintf(stderr,"not enough free memory\n");
		return 1;
	}

	if ((outdata = malloc(len + OID_LEN)) == NULL) {
		fprintf(stderr,"not enough free memory\n");
		return 1;
	}

	memcpy(outdata,OID,OID_LEN);
	memcpy(outdata+OID_LEN,indata,len);

	len += OID_LEN;

	indata = asn1_enc(outdata, 0x30, &len);

	if (indata == NULL) {
		fprintf(stderr,"not enough free memory\n");
		return 1;
	}

	if ((fp = fopen(argv[2],"w")) == NULL) {
		fprintf(stderr,"could not open output file \"%s\"\n", argv[2]);
		return 1;
	}

	fwrite(indata,1,len,fp);
	fclose(fp);

	printf("done. wrote %d octets\n", len);
	
	return 0;
}

