blob: eec0ee84943fa08f28b31d8a29529f1b845cacf6 [file] [log] [blame]
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package asn1
import (
"bytes"
"fmt"
"io"
"os"
"reflect"
"strings"
"time"
)
// A forkableWriter is an in-memory buffer that can be 'forked' into two
// child forkableWriters whose contents follow its own. After
//	pre, post := w.fork();
// the bytes represented are, in order, w's own buffer, then pre, then post.
// Forking lets a header be written into pre after the body that follows it
// (in post) has already been produced and measured.
type forkableWriter struct {
	*bytes.Buffer
	pre  *forkableWriter // emitted after the embedded buffer; nil until fork
	post *forkableWriter // emitted after pre; nil until fork
}
// newForkableWriter returns a forkableWriter with an empty buffer and no
// children.
func newForkableWriter() *forkableWriter {
	w := &forkableWriter{Buffer: new(bytes.Buffer)}
	return w
}
// fork gives f two fresh, empty children and returns them. When f is written
// out its own bytes come first, then pre, then post. Any children f already
// had are replaced.
func (f *forkableWriter) fork() (pre, post *forkableWriter) {
	pre = newForkableWriter()
	post = newForkableWriter()
	f.pre, f.post = pre, post
	return
}
// Len reports the total number of bytes held by f's own buffer plus,
// recursively, those of its pre and post children (if present).
func (f *forkableWriter) Len() (l int) {
	l = f.Buffer.Len()
	for _, child := range []*forkableWriter{f.pre, f.post} {
		if child != nil {
			l += child.Len()
		}
	}
	return
}
// writeTo writes f's own bytes, then its pre child, then its post child
// (each recursively) to out. n is the total number of bytes written; writing
// stops at the first error, so out may have received a prefix of the data.
func (f *forkableWriter) writeTo(out io.Writer) (n int, err os.Error) {
	n, err = out.Write(f.Bytes())
	if err != nil {
		return
	}

	var nn int

	if f.pre != nil {
		nn, err = f.pre.writeTo(out)
		n += nn
		if err != nil {
			return
		}
	}

	// Each child is guarded by its own nil check; post must not be assumed
	// non-nil just because pre is.
	if f.post != nil {
		nn, err = f.post.writeTo(out)
		n += nn
	}

	return
}
// marshalBase128Int writes the non-negative value i in base 128, most
// significant 7-bit group first, with the high bit set on every byte except
// the last. Zero is written as a single 0x00 byte. This is the form used for
// high tag numbers and for object identifier arcs.
//
// Negative values have no base-128 representation and are rejected rather
// than silently producing no output.
func marshalBase128Int(out *forkableWriter, i int64) (err os.Error) {
	if i < 0 {
		return StructuralError{"negative value in base-128 encoding"}
	}
	if i == 0 {
		return out.WriteByte(0)
	}

	// Count the 7-bit groups first so they can be emitted big-endian.
	groups := 0
	for v := i; v > 0; v >>= 7 {
		groups++
	}

	for g := groups - 1; g >= 0; g-- {
		o := byte(i>>uint(g*7)) & 0x7f
		if g != 0 {
			// Continuation bit: more groups follow.
			o |= 0x80
		}
		err = out.WriteByte(o)
		if err != nil {
			return
		}
	}

	return nil
}
// base128Length returns how many 7-bit groups are needed to represent i.
// Zero still takes one byte; negative values yield 0.
func base128Length(i int) (numBytes int) {
	switch {
	case i == 0:
		return 1
	case i < 0:
		return 0
	}
	for rest := i; rest != 0; rest >>= 7 {
		numBytes++
	}
	return numBytes
}
// marshalTagAndLength writes the identifier and length octets described by t.
//
// Identifier: class in the top two bits, 0x20 if constructed. Tags below 31
// go in the low five bits; larger tags set those bits to 0x1f and follow in
// base 128.
//
// Length: values below 128 use the one-byte short form. Larger values use
// the long form: a byte 0x80|n, followed by the length as n big-endian
// base-256 octets.
func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err os.Error) {
	b := uint8(t.class) << 6
	if t.isCompound {
		b |= 0x20
	}
	if t.tag >= 31 {
		b |= 0x1f
		err = out.WriteByte(b)
		if err != nil {
			return
		}
		err = marshalBase128Int(out, int64(t.tag))
		if err != nil {
			return
		}
	} else {
		b |= uint8(t.tag)
		err = out.WriteByte(b)
		if err != nil {
			return
		}
	}

	if t.length >= 128 {
		// The count byte must carry the 0x80 flag to mark the long form.
		l := lengthLength(t.length)
		err = out.WriteByte(0x80 | byte(l))
		if err != nil {
			return
		}
		err = marshalLength(out, t.length)
		if err != nil {
			return
		}
	} else {
		err = out.WriteByte(byte(t.length))
		if err != nil {
			return
		}
	}

	return nil
}

// lengthLength returns how many base-256 octets are needed to hold the
// non-negative length i (always at least one).
func lengthLength(i int) (numBytes int) {
	numBytes = 1
	for i > 255 {
		numBytes++
		i >>= 8
	}
	return
}

// marshalLength writes i big-endian in exactly lengthLength(i) octets.
func marshalLength(out *forkableWriter, i int) (err os.Error) {
	for n := lengthLength(i); n > 0; n-- {
		err = out.WriteByte(byte(i >> uint((n-1)*8)))
		if err != nil {
			return
		}
	}
	return nil
}
// marshalBitString writes the contents of a BIT STRING: one octet giving the
// number of unused (padding) bits in the final byte, derived from
// b.BitLength, followed by b.Bytes unchanged.
func marshalBitString(out *forkableWriter, b BitString) (err os.Error) {
	if err = out.WriteByte(byte((8 - b.BitLength%8) % 8)); err == nil {
		_, err = out.Write(b.Bytes)
	}
	return
}
// marshalObjectIdentifier writes the contents of an OBJECT IDENTIFIER. The
// first two arcs are packed into one octet as 40*oid[0]+oid[1]; each
// remaining arc follows in base 128.
//
// The whole OID is validated before anything is written, so an invalid
// identifier produces no partial output.
func marshalObjectIdentifier(out *forkableWriter, oid []int) (err os.Error) {
	// Only first arcs 0, 1 and 2 exist. Limiting the second arc to 0..39
	// keeps the packed first octet at most 2*40+39 = 119. A looser bound
	// would let the value exceed 255 and wrap silently in byte().
	if len(oid) < 2 || oid[0] < 0 || oid[0] > 2 || oid[1] < 0 || oid[1] >= 40 {
		return StructuralError{"invalid object identifier"}
	}
	for _, arc := range oid[2:] {
		if arc < 0 {
			return StructuralError{"invalid object identifier"}
		}
	}

	err = out.WriteByte(byte(oid[0]*40 + oid[1]))
	if err != nil {
		return
	}

	for _, arc := range oid[2:] {
		err = marshalBase128Int(out, int64(arc))
		if err != nil {
			return
		}
	}

	return nil
}
// marshalPrintableString writes the bytes of s unchanged, after checking that
// each one is accepted by isPrintable. If any byte is rejected, nothing is
// written and a StructuralError is returned.
func marshalPrintableString(out *forkableWriter, s string) (err os.Error) {
	encoded := strings.Bytes(s)
	for i := 0; i < len(encoded); i++ {
		if !isPrintable(encoded[i]) {
			return StructuralError{"PrintableString contains invalid character"}
		}
	}
	_, err = out.Write(encoded)
	return
}
// marshalIA5String writes the bytes of s unchanged, provided every byte is
// 7-bit (below 0x80). Otherwise nothing is written and a StructuralError is
// returned.
func marshalIA5String(out *forkableWriter, s string) (err os.Error) {
	raw := strings.Bytes(s)
	for i := 0; i < len(raw); i++ {
		if raw[i] >= 0x80 {
			return StructuralError{"IA5String contains invalid character"}
		}
	}
	_, err = out.Write(raw)
	return
}
// marshalTwoDigits writes the last two decimal digits of v as ASCII, tens
// digit first. NOTE(review): intended for non-negative v; a negative v
// yields bytes outside '0'..'9'.
func marshalTwoDigits(out *forkableWriter, v int) (err os.Error) {
	tens := byte('0' + (v/10)%10)
	ones := byte('0' + v%10)
	if err = out.WriteByte(tens); err != nil {
		return
	}
	return out.WriteByte(ones)
}
// marshalUTCTime writes t as the contents of an ASN.1 UTCTime:
// YYMMDDhhmmss followed either by 'Z' (when t.ZoneOffset/60 is zero) or by a
// sign and an hhmm offset.
//
// Years outside 1950..2049 have no two-digit form. They are rejected before
// anything is written.
func marshalUTCTime(out *forkableWriter, t *time.Time) (err os.Error) {
	var twoDigitYear int
	if t.Year >= 1950 && t.Year < 2000 {
		twoDigitYear = int(t.Year - 1900)
	} else if t.Year >= 2000 && t.Year < 2050 {
		twoDigitYear = int(t.Year - 2000)
	} else {
		return StructuralError{"Cannot represent time as UTCTime"}
	}

	for _, field := range []int{twoDigitYear, t.Month, t.Day, t.Hour, t.Minute, t.Second} {
		if err = marshalTwoDigits(out, field); err != nil {
			return
		}
	}

	offsetMinutes := t.ZoneOffset / 60
	if offsetMinutes == 0 {
		return out.WriteByte('Z')
	}

	sign := byte('+')
	if offsetMinutes < 0 {
		sign = '-'
		offsetMinutes = -offsetMinutes
	}
	if err = out.WriteByte(sign); err != nil {
		return
	}
	if err = marshalTwoDigits(out, offsetMinutes/60); err != nil {
		return
	}
	return marshalTwoDigits(out, offsetMinutes%60)
}
// marshalBody writes the contents octets (no tag or length) for value.
//
// Special types (time, BitString, ObjectIdentifier) are matched first by
// exact type. Otherwise the reflected kind selects the encoding:
//   - bools become a single byte, 1 or 0;
//   - ints become big-endian two's complement in the minimum number of octets;
//   - structs marshal each field with the parameters from its tag;
//   - []uint8 is copied verbatim;
//   - other slices marshal each element with default parameters.
//
// For strings, params.stringType chooses IA5String or PrintableString.
// Any other type yields a StructuralError.
func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err os.Error) {
	switch value.Type() {
	case timeType:
		return marshalUTCTime(out, value.Interface().(*time.Time))
	case bitStringType:
		return marshalBitString(out, value.Interface().(BitString))
	case objectIdentifierType:
		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
	}

	switch v := value.(type) {
	case *reflect.BoolValue:
		if v.Get() {
			return out.WriteByte(1)
		}
		return out.WriteByte(0)
	case *reflect.IntValue:
		// INTEGER contents are two's complement, not base 128.
		return marshalInt64(out, int64(v.Get()))
	case *reflect.Int64Value:
		return marshalInt64(out, v.Get())
	case *reflect.StructValue:
		t := v.Type().(*reflect.StructType)
		for i := 0; i < t.NumField(); i++ {
			err = marshalField(out, v.Field(i), parseFieldParameters(t.Field(i).Tag))
			if err != nil {
				return
			}
		}
		return
	case *reflect.SliceValue:
		sliceType := v.Type().(*reflect.SliceType)
		if _, ok := sliceType.Elem().(*reflect.Uint8Type); ok {
			buf := make([]byte, v.Len())
			for i := 0; i < v.Len(); i++ {
				buf[i] = v.Elem(i).(*reflect.Uint8Value).Get()
			}
			_, err = out.Write(buf)
			return
		}
		// Elements have no struct tag of their own, so they are marshalled
		// with default parameters rather than the slice's.
		var elemParams fieldParameters
		for i := 0; i < v.Len(); i++ {
			err = marshalField(out, v.Elem(i), elemParams)
			if err != nil {
				return
			}
		}
		return
	case *reflect.StringValue:
		if params.stringType == tagIA5String {
			return marshalIA5String(out, v.Get())
		}
		return marshalPrintableString(out, v.Get())
	}

	return StructuralError{"unknown Go type"}
}

// int64Length returns the minimum number of octets needed to hold i in
// two's complement (always at least one).
func int64Length(i int64) (numBytes int) {
	numBytes = 1
	for i > 127 {
		numBytes++
		i >>= 8
	}
	for i < -128 {
		numBytes++
		i >>= 8
	}
	return
}

// marshalInt64 writes i as big-endian two's complement in int64Length(i)
// octets, as required for the contents of an ASN.1 INTEGER.
func marshalInt64(out *forkableWriter, i int64) (err os.Error) {
	for n := int64Length(i); n > 0; n-- {
		err = out.WriteByte(byte(i >> uint((n-1)*8)))
		if err != nil {
			return
		}
	}
	return nil
}
// marshalField writes the complete encoding of v: identifier and length
// octets followed by the body produced by marshalBody.
//
// The tag starts as the universal tag for v's type and may be adjusted by
// params:
//   - params.stringType replaces the tag of a PrintableString-typed member
//     (any other member with a stringType is an error);
//   - params.tag without params.explicit replaces the tag with a
//     context-specific one (implicit tagging);
//   - params.tag with params.explicit keeps the universal tag and prepends an
//     extra constructed, context-specific tag whose length covers the inner
//     header and body.
//
// The body is marshalled first so its length is known. The headers are then
// written into forks that precede the body in the output.
func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err os.Error) {
	tag, isCompound, ok := getUniversalType(v.Type())
	if !ok {
		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
		return
	}
	class := classUniversal

	if params.stringType != 0 {
		if tag != tagPrintableString {
			return StructuralError{"Explicit string type given to non-string member"}
		}
		tag = params.stringType
	}

	// Explicit tagging dereferences params.tag below; report a missing tag
	// number instead of panicking on a nil pointer.
	if params.explicit && params.tag == nil {
		return StructuralError{"Explicit tagging requested without a tag number"}
	}

	tags, body := out.fork()

	err = marshalBody(body, v, params)
	if err != nil {
		return
	}

	bodyLen := body.Len()

	var explicitTag *forkableWriter
	if params.explicit {
		// Split the header area so the outer explicit tag is emitted before
		// the inner tag.
		explicitTag, tags = tags.fork()
	}

	if !params.explicit && params.tag != nil {
		// implicit tag.
		tag = *params.tag
		class = classContextSpecific
	}

	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
	if err != nil {
		return
	}

	if params.explicit {
		err = marshalTagAndLength(explicitTag, tagAndLength{
			class:      classContextSpecific,
			tag:        *params.tag,
			length:     bodyLen + tags.Len(),
			isCompound: true,
		})
		if err != nil {
			return
		}
	}

	return nil
}
// Marshal serialises val as an ASN.1 structure and writes the result to out.
// The whole encoding is built in memory before anything is written, so an
// encoding error produces no output. An error returned by out itself may
// still leave a partial write.
func Marshal(out io.Writer, val interface{}) os.Error {
	buf := newForkableWriter()
	if err := marshalField(buf, reflect.NewValue(val), fieldParameters{}); err != nil {
		return err
	}
	_, err := buf.writeTo(out)
	return err
}