Skip to content

Add custom prefix to autogenerated C functions #268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 82 additions & 60 deletions cmd/mkcgo/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
// This block outputs C header includes and forward declarations for loader functions.
fmt.Fprintf(w, "/*\n")
fmt.Fprintf(w, "#cgo CFLAGS: -Wno-attributes\n\n")
for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
}
Expand All @@ -39,10 +36,10 @@ func generateGo(src *mkcgo.Source, w io.Writer) {
// Generate Go wrapper functions that load and unload the C symbols.
for _, tag := range src.Tags() {
fmt.Fprintf(w, "func mkcgoLoad_%s(handle unsafe.Pointer) {\n", tag)
fmt.Fprintf(w, "\tC.__mkcgoLoad_%s(handle)\n", tag)
fmt.Fprintf(w, "\tC.__mkcgo_load_%s(handle)\n", tag)
fmt.Fprintf(w, "}\n\n")
fmt.Fprintf(w, "func mkcgoUnload_%s() {\n", tag)
fmt.Fprintf(w, "\tC.__mkcgoUnload_%s()\n", tag)
fmt.Fprintf(w, "\tC.__mkcgo_unload_%s()\n", tag)
fmt.Fprintf(w, "}\n\n")
}

Expand All @@ -57,14 +54,13 @@ func generateGo(src *mkcgo.Source, w io.Writer) {

// Generate function wrappers.
for _, fn := range src.Funcs {
if fn.Variadic() {
// cgo doesn't support variadic functions
if !fnCalledFromGo(fn) {
continue
}
if fn.Optional {
// Generate a function that returns true if the function is available.
fmt.Fprintf(w, "func %s_Available() bool {\n", fn.GoName)
fmt.Fprintf(w, "\treturn C.%s_Available() != 0\n", fn.ImportName)
fmt.Fprintf(w, "func %s() bool {\n", fnGoNameAvailable(fn))
fmt.Fprintf(w, "\treturn C.%s() != 0\n", fnCNameAvailable(fn))
fmt.Fprintf(w, "}\n\n")
}
generateGoFn(fn, w)
Expand All @@ -81,10 +77,7 @@ func generateGo124(src *mkcgo.Source, w io.Writer) {
// This block outputs C header includes and forward declarations for loader functions.
fmt.Fprintf(w, "/*\n")
for _, fn := range src.Funcs {
name := fn.CName
if fnNeedErrWrapper(fn) {
name = fnCErrWrapperName(fn)
}
name := fnCName(fn)
if fn.NoEscape {
fmt.Fprintf(w, "#cgo noescape %s\n", name)
}
Expand Down Expand Up @@ -155,13 +148,26 @@ func generateCHeader(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "#ifndef MKCGO_H // only include this header once\n")
fmt.Fprintf(w, "#define MKCGO_H\n\n")

for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
fmt.Fprintf(w, "#include %q\n\n", *includeHeader)
}

// Source includes.
for _, v := range src.Includes {
fmt.Fprintf(w, "#include %s\n", v)
}
fmt.Fprintf(w, "\n")

// Source types and enums.
for _, def := range src.TypeDefs {
fmt.Fprintf(w, "typedef %s %s;\n", def.Type, def.Name)
}
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "enum {\n")
for _, enum := range src.Enums {
fmt.Fprintf(w, "\t%s = %s,\n", enum.Name, enum.Value)
}
fmt.Fprintf(w, "};\n\n")

// Custom types
fmt.Fprintf(w, "typedef void* %s;\n", mkcgoErrState)
Expand All @@ -171,25 +177,30 @@ func generateCHeader(src *mkcgo.Source, w io.Writer) {

// Add forward declarations for loader functions.
for _, tag := range src.Tags() {
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle);\n", tag)
fmt.Fprintf(w, "void __mkcgoUnload_%s();\n", tag)
fmt.Fprintf(w, "void __mkcgo_load_%s(void* handle);\n", tag)
fmt.Fprintf(w, "void __mkcgo_unload_%s();\n", tag)
}
fmt.Fprintf(w, "\n")

// Add forward declarations for optional functions.
for _, fn := range src.Funcs {
if fn.Optional {
fmt.Fprintf(w, "int %s_Available();\n", fn.ImportName)
fmt.Fprintf(w, "int %s();\n", fnCNameAvailable(fn))
}
}
fmt.Fprintf(w, "\n")

// Add forward declarations for function wrappers returning errors.
for _, fn := range src.Funcs {
if !fnNeedErrWrapper(fn) {
if !fnCalledFromGo(fn) {
// cgo doesn't support variadic functions, no need to include them.
continue
}
fmt.Fprintf(w, "%s %s(%s);\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, false))
if fnNeedErrWrapper(fn) {
fmt.Fprintf(w, "%s %s(%s);\n", fn.Ret.Type, fnCName(fn), fnCErrWrapperParams(fn, false))
} else {
fmt.Fprintf(w, "%s %s(%s);\n", fn.Ret.Type, fnCName(fn), fnToCArgs(fn, true, false))
}
}
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "#endif // MKCGO_H\n")
Expand All @@ -204,11 +215,8 @@ func generateC(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "#include <stdlib.h>\n")
fmt.Fprintf(w, "#include <stdint.h>\n")
fmt.Fprintf(w, "#include <stdio.h>\n")
for _, file := range src.Files {
fmt.Fprintf(w, "#include %q\n", file)
}
if *includeHeader != "" {
fmt.Fprintf(w, "#include \"%s\"\n", *includeHeader)
fmt.Fprintf(w, "#include %q\n", *includeHeader)
}
fmt.Fprintf(w, "#include \"%s\"\n", autogeneratedFileName(".h"))
fmt.Fprintf(w, "\n")
Expand All @@ -222,12 +230,11 @@ func generateC(src *mkcgo.Source, w io.Writer) {
fmt.Fprintf(w, "#endif\n\n")

// Function pointer declarations.
fmt.Fprintf(w, "#define __mkcgo__funcptr(name) typeof(name) *_g_##name;\n\n")
for _, fn := range src.Funcs {
if fn.VariadicInst {
continue
}
fmt.Fprintf(w, "__mkcgo__funcptr(%s);\n", fn.ImportName)
fmt.Fprintf(w, "%s (*_g_%s)(%s);\n", fn.Ret.Type, fn.ImportName, fnToCArgs(fn, true, false))
}
fmt.Fprintf(w, "\n")

Expand All @@ -242,7 +249,7 @@ func generateC(src *mkcgo.Source, w io.Writer) {

// Loader and unloader functions for each tag.
for _, tag := range src.Tags() {
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle) {\n", tag)
fmt.Fprintf(w, "void __mkcgo_load_%s(void* handle) {\n", tag)
for _, fn := range src.Funcs {
if fn.VariadicInst {
continue
Expand All @@ -267,7 +274,7 @@ func generateC(src *mkcgo.Source, w io.Writer) {
}
fmt.Fprintf(w, "}\n\n")

fmt.Fprintf(w, "void __mkcgoUnload_%s() {\n", tag)
fmt.Fprintf(w, "void __mkcgo_unload_%s() {\n", tag)
for _, fn := range src.Funcs {
if fn.VariadicInst {
continue
Expand All @@ -292,26 +299,25 @@ func generateC(src *mkcgo.Source, w io.Writer) {
typedefs[def.Name] = def.Type
}
for _, fn := range src.Funcs {
if fn.Variadic() {
// cgo doesn't support variadic functions
if !fnCalledFromGo(fn) {
// cgo doesn't support variadic functions, no need to include them.
continue
}
if fn.Optional {
// Generate a function that returns true if the function is available.
fmt.Fprintf(w, "int %s_Available() {\n", fn.CName)
fmt.Fprintf(w, "int %s() {\n", fnCNameAvailable(fn))
fmt.Fprintf(w, "\treturn _g_%s != NULL;\n", fn.ImportName)
fmt.Fprintf(w, "}\n\n")
}
generateCFn(fn, w)
generateCFnErrorWrapper(typedefs, fn, w)
generateCFn(typedefs, fn, w)
}
}

// generateGoFn generates Go function f.
func generateGoFn(fn *mkcgo.Func, w io.Writer) {
fnCall := fmt.Sprintf("C.%s(%s)", fn.CName, fnToGoArgs(fn))
fnCall := fmt.Sprintf("C.%s(%s)", fnCName(fn), fnToGoArgs(fn))
// Function definition
fmt.Fprintf(w, "func %s(%s)", fn.GoName, fnToGoParams(fn))
fmt.Fprintf(w, "func %s(%s)", fnGoName(fn), fnToGoParams(fn))
if retIsVoid(fn.Ret) {
// Easy path, just call the C function. No need to write the return types,
// nor do error handling, nor cast the return value.
Expand All @@ -320,17 +326,16 @@ func generateGoFn(fn *mkcgo.Func, w io.Writer) {
fmt.Fprintf(w, "}\n\n")
return
}
typ, _ := cTypeToGo(fn.Ret.Type, false)
goType, needCast := cTypeToGo(fn.Ret.Type, false)
if fn.NoError {
fmt.Fprintf(w, " %s ", typ)
fmt.Fprintf(w, " %s ", goType)
} else {
fmt.Fprintf(w, " (%s, error) ", typ)
fmt.Fprintf(w, " (%s, error) ", goType)
}
fmt.Fprintf(w, "{\n")

// Function call
var needUnsafeCast bool
goType, needCast := cTypeToGo(fn.Ret.Type, false)
if needCast && goType[0] == '*' {
goType = fmt.Sprintf("(%s)(unsafe.Pointer", goType)
needUnsafeCast = true
Expand All @@ -351,7 +356,7 @@ func generateGoFn(fn *mkcgo.Func, w io.Writer) {
return
}
fmt.Fprintf(w, "\tvar _err C.%s\n", mkcgoErrState)
fmt.Fprintf(w, "\t_ret := C.%s(", fnCErrWrapperName(fn))
fmt.Fprintf(w, "\t_ret := C.%s(", fnCName(fn))
args := fnToGoArgs(fn)
if len(args) > 0 {
args += ", "
Expand All @@ -368,26 +373,22 @@ func generateGoFn(fn *mkcgo.Func, w io.Writer) {
} else {
fmt.Fprintf(w, "_ret")
}
fmt.Fprintf(w, ", newMkcgoErr(%q, _err)\n", fn.CName)
fmt.Fprintf(w, "}\n\n")
}

func generateCFn(fn *mkcgo.Func, w io.Writer) {
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fn.CName, fnToCArgs(fn, true, true))
if !retIsVoid(fn.Ret) {
fmt.Fprintf(w, "return ")
}
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false, true))
fmt.Fprintf(w, ", newMkcgoErr(%q, _err)\n", fn.Name)
fmt.Fprintf(w, "}\n\n")
}

// generateCFnErrorWrapper generates C function wrapper for function f
// that returns an error state.
func generateCFnErrorWrapper(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
func generateCFn(typedefs map[string]string, fn *mkcgo.Func, w io.Writer) {
if !fnNeedErrWrapper(fn) {
fmt.Fprintf(w, "%s %s(%s) {\n\t", fn.Ret.Type, fnCName(fn), fnToCArgs(fn, true, true))
if !retIsVoid(fn.Ret) {
fmt.Fprintf(w, "return ")
}
fmt.Fprintf(w, "_g_%s(%s);\n", fn.ImportName, fnToCArgs(fn, false, true))
fmt.Fprintf(w, "}\n\n")
return
}
fmt.Fprintf(w, "%s %s(%s) {\n", fn.Ret.Type, fnCErrWrapperName(fn), fnCErrWrapperParams(fn, true))

fmt.Fprintf(w, "%s %s(%s) {\n", fn.Ret.Type, fnCName(fn), fnCErrWrapperParams(fn, true))
fmt.Fprintf(w, "\tmkcgo_err_clear();\n") // clear any previous error
fmt.Fprintf(w, "\t%s _ret = _g_%s(%s);\n", fn.Ret.Type, fn.ImportName, fnToCArgs(fn, false, true))
errCond := "<= 0"
Expand Down Expand Up @@ -506,7 +507,7 @@ func cTypeToGo(t string, cgo bool) (string, bool) {
// paramToC returns C source code of parameter p.
func paramToC(i int, p *mkcgo.Param, addType, addName bool) string {
if p.Type == "..." {
return ""
return "..."
}
var s string
if addType {
Expand Down Expand Up @@ -585,12 +586,33 @@ func fnCErrWrapperParams(fn *mkcgo.Func, addName bool) string {
return args
}

// fnCErrWrapperName returns the name of the error wrapper function for function f.
func fnCErrWrapperName(fn *mkcgo.Func) string {
return "_mkcgo_err_" + fn.CName
// fnGoName returns the Go function name for function f.
func fnGoName(fn *mkcgo.Func) string {
// TODO: use a prefix that is not OpenSSL specific.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO, open an issue or fix?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #269 😄

return "go_openssl_" + fn.Name
}

func fnGoNameAvailable(fn *mkcgo.Func) string {
return fnGoName(fn) + "_Available"
}

// fnCName returns the C function name for function f.
func fnCName(fn *mkcgo.Func) string {
return "_mkcgo_" + fn.Name
}

// fnCNameAvailable returns the C function name for function f
// that checks if the function is available.
func fnCNameAvailable(fn *mkcgo.Func) string {
return "_mkcgo_available_" + fn.Name
}

// fnNeedErrWrapper reports whether function fn needs an error wrapper.
func fnNeedErrWrapper(fn *mkcgo.Func) bool {
return !fn.NoError && !retIsVoid(fn.Ret)
}

// fnCalledFromGo reports whether function fn is called from Go code.
func fnCalledFromGo(fn *mkcgo.Func) bool {
return !fn.Variadic() // cgo doesn't support variadic functions
}
9 changes: 4 additions & 5 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package openssl
/*
#include <stdlib.h> // for calloc and free
#include <string.h> // for strdup
#include "shims.h"
#include "zossl.h"

// OpenSSL only allows a maximum of 16 errors to be stored in the error queue.
Expand All @@ -22,7 +21,7 @@ typedef struct ossl_err_state_st {
// mkcgo_err_clear clears the error queue in OpenSSL.
void mkcgo_err_clear() {
// Clear the error queue.
ERR_clear_error();
_mkcgo_ERR_clear_error();
}

// mkcgo_err_retrieve retrieves the error state from OpenSSL.
Expand All @@ -37,12 +36,12 @@ mkcgo_err_state mkcgo_err_retrieve() {
// Retrieve the errors from OpenSSL.
for (int i = 0; i < ERR_NUM_MAX; i++) {
const char *file;
if (OPENSSL_version_major_Available() == 1) { // Only available in OpenSSL 3.
if (_mkcgo_available_OPENSSL_version_major() == 1) { // Only available in OpenSSL 3.
// OpenSSL 3 error handling
errs->code[i] = ERR_get_error_all(&file, &errs->line[i], NULL, NULL, NULL);
errs->code[i] = _mkcgo_ERR_get_error_all(&file, &errs->line[i], NULL, NULL, NULL);
} else {
// OpenSSL 1 error handling
errs->code[i] = ERR_get_error_line(&file, &errs->line[i]);
errs->code[i] = _mkcgo_ERR_get_error_line(&file, &errs->line[i]);
}
if (errs->code[i] == 0) {
break;
Expand Down
5 changes: 2 additions & 3 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package openssl

/*
#include "shims.h"
#include "zossl.h"
// go_hash_sum copies ctx into ctx2 and calls EVP_DigestFinal_ex using ctx2.
// This is necessary because Go hash.Hash mandates that Sum has no effect
Expand All @@ -14,9 +13,9 @@ package openssl
static inline int
go_hash_sum(const _EVP_MD_CTX_PTR ctx, _EVP_MD_CTX_PTR ctx2, unsigned char *out, mkcgo_err_state *_err_state)
{
if (_mkcgo_err_EVP_MD_CTX_copy(ctx2, ctx, _err_state) != 1)
if (_mkcgo_EVP_MD_CTX_copy(ctx2, ctx, _err_state) != 1)
return -1;
if (_mkcgo_err_EVP_DigestFinal_ex(ctx2, out, NULL, _err_state) <= 0)
if (_mkcgo_EVP_DigestFinal_ex(ctx2, out, NULL, _err_state) <= 0)
return -2;
return 1;
}
Expand Down
4 changes: 2 additions & 2 deletions internal/mkcgo/mkcgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Source struct {
Funcs []*Func
Files []string
Comments []string // All line comments. Directives in this slice start with "#"
Includes []string // All #include directives, without the #include prefix.
}

// TypeDef describes a type definition.
Expand All @@ -28,8 +29,7 @@ type Enum struct {
// Func describes a function.
type Func struct {
FuncAttributes
GoName string
CName string
Name string
Params []*Param
Ret *Return
}
Expand Down
Loading
Loading